From 2df4bd45f9f152b65e1ba6276722025f31ec920a Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 3 Jul 2024 21:05:13 +0000 Subject: [PATCH 01/54] Enable black and isort --- Makefile | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/Makefile b/Makefile index 128b3dc849..d7ca684bc7 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,7 @@ conda-env: build/conda-env.${CONDA_ENV_NAME}.build-stamp MLOS_CORE_CONF_FILES := mlos_core/pyproject.toml mlos_core/setup.py mlos_core/MANIFEST.in MLOS_BENCH_CONF_FILES := mlos_bench/pyproject.toml mlos_bench/setup.py mlos_bench/MANIFEST.in MLOS_VIZ_CONF_FILES := mlos_viz/pyproject.toml mlos_viz/setup.py mlos_viz/MANIFEST.in -MLOS_GLOBAL_CONF_FILES := setup.cfg # pyproject.toml +MLOS_GLOBAL_CONF_FILES := setup.cfg pyproject.toml MLOS_PKGS := mlos_core mlos_bench mlos_viz MLOS_PKG_CONF_FILES := $(MLOS_CORE_CONF_FILES) $(MLOS_BENCH_CONF_FILES) $(MLOS_VIZ_CONF_FILES) $(MLOS_GLOBAL_CONF_FILES) @@ -164,9 +164,7 @@ build/black.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_COMMON_PREREQS) touch $@ .PHONY: check -check: pycodestyle pydocstyle pylint mypy # cspell markdown-link-check -# TODO: Enable isort and black checks -#check: isort-check black-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check +check: isort-check black-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check .PHONY: black-check black-check: build/black-check.mlos_core.${CONDA_ENV_NAME}.build-stamp @@ -723,7 +721,10 @@ clean-doc: .PHONY: clean-format clean-format: - # TODO: add black and isort rules + rm -f build/black.${CONDA_ENV_NAME}.build-stamp + rm -f build/black.mlos_*.${CONDA_ENV_NAME}.build-stamp + rm -f build/isort.${CONDA_ENV_NAME}.build-stamp + rm -f build/isort.mlos_*.${CONDA_ENV_NAME}.build-stamp rm -f build/licenseheaders.${CONDA_ENV_NAME}.build-stamp rm -f build/licenseheaders-prereqs.${CONDA_ENV_NAME}.build-stamp @@ -733,6 +734,11 @@ clean-check: rm -f build/pylint.${CONDA_ENV_NAME}.build-stamp rm -f build/pylint.mlos_*.${CONDA_ENV_NAME}.build-stamp rm -f build/mypy.mlos_*.${CONDA_ENV_NAME}.build-stamp + rm -f build/black-check.build-stamp + rm -f build/black-check.${CONDA_ENV_NAME}.build-stamp + rm -f build/black-check.mlos_*.${CONDA_ENV_NAME}.build-stamp + rm -f build/isort-check.${CONDA_ENV_NAME}.build-stamp + rm -f build/isort-check.mlos_*.${CONDA_ENV_NAME}.build-stamp rm -f build/pycodestyle.build-stamp rm -f build/pycodestyle.${CONDA_ENV_NAME}.build-stamp rm -f build/pycodestyle.mlos_*.${CONDA_ENV_NAME}.build-stamp From 96b233c49b1fe2efc0407d03f3130db95a2b8e25 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 3 Jul 2024 21:12:59 +0000 Subject: [PATCH 02/54] remove autopep8 --- .devcontainer/devcontainer.json | 1 - conda-envs/mlos-3.10.yml | 1 - conda-envs/mlos-3.11.yml | 1 - conda-envs/mlos-3.8.yml | 1 - conda-envs/mlos-3.9.yml | 1 - conda-envs/mlos-windows.yml | 1 - conda-envs/mlos.yml | 1 - 7 files changed, 7 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 76a699a4d3..93eddacf6d 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -45,7 +45,6 @@ // Adjust the python interpreter path to point to the conda environment "python.defaultInterpreterPath": "/opt/conda/envs/mlos/bin/python", "python.testing.pytestPath": "/opt/conda/envs/mlos/bin/pytest", - "python.formatting.autopep8Path": "/opt/conda/envs/mlos/bin/autopep8", "python.linting.pylintPath": "/opt/conda/envs/mlos/bin/pylint", "pylint.path": [ 
"/opt/conda/envs/mlos/bin/pylint" diff --git a/conda-envs/mlos-3.10.yml b/conda-envs/mlos-3.10.yml index 5a63c126e8..4614b28a78 100644 --- a/conda-envs/mlos-3.10.yml +++ b/conda-envs/mlos-3.10.yml @@ -25,7 +25,6 @@ dependencies: # See comments in mlos.yml. #- gcc_linux-64 - pip: - - autopep8>=1.7.0 - bump2version - check-jsonschema - isort diff --git a/conda-envs/mlos-3.11.yml b/conda-envs/mlos-3.11.yml index 488731fd59..9680186660 100644 --- a/conda-envs/mlos-3.11.yml +++ b/conda-envs/mlos-3.11.yml @@ -25,7 +25,6 @@ dependencies: # See comments in mlos.yml. #- gcc_linux-64 - pip: - - autopep8>=1.7.0 - bump2version - check-jsonschema - isort diff --git a/conda-envs/mlos-3.8.yml b/conda-envs/mlos-3.8.yml index 88f43726af..1cfb0e18d2 100644 --- a/conda-envs/mlos-3.8.yml +++ b/conda-envs/mlos-3.8.yml @@ -25,7 +25,6 @@ dependencies: # See comments in mlos.yml. #- gcc_linux-64 - pip: - - autopep8>=1.7.0 - bump2version - check-jsonschema - isort diff --git a/conda-envs/mlos-3.9.yml b/conda-envs/mlos-3.9.yml index f1c27b7176..75cee3baee 100644 --- a/conda-envs/mlos-3.9.yml +++ b/conda-envs/mlos-3.9.yml @@ -25,7 +25,6 @@ dependencies: # See comments in mlos.yml. #- gcc_linux-64 - pip: - - autopep8>=1.7.0 - bump2version - check-jsonschema - isort diff --git a/conda-envs/mlos-windows.yml b/conda-envs/mlos-windows.yml index d1063b6613..190c2699e5 100644 --- a/conda-envs/mlos-windows.yml +++ b/conda-envs/mlos-windows.yml @@ -28,7 +28,6 @@ dependencies: # This also requires a more recent vs2015_runtime from conda-forge. - pyrfr>=0.9.0 - pip: - - autopep8>=1.7.0 - bump2version - check-jsonschema - isort diff --git a/conda-envs/mlos.yml b/conda-envs/mlos.yml index 0e6e052a51..51ce8077a8 100644 --- a/conda-envs/mlos.yml +++ b/conda-envs/mlos.yml @@ -24,7 +24,6 @@ dependencies: # FIXME: https://github.com/microsoft/MLOS/issues/727 - python<3.12 - pip: - - autopep8>=1.7.0 - bump2version - check-jsonschema - isort From 949f2cd04445f8caa3be336dcbf73ab3f7173c32 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 3 Jul 2024 21:13:19 +0000 Subject: [PATCH 03/54] Enable black and isort formatters in vscode --- .devcontainer/devcontainer.json | 5 ++--- .vscode/extensions.json | 5 ++--- .vscode/settings.json | 10 +++------- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 93eddacf6d..1f56be9d1a 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -70,9 +70,8 @@ "lextudio.restructuredtext", "matangover.mypy", "ms-azuretools.vscode-docker", - // TODO: Enable additional formatter extensions: - //"ms-python.black-formatter", - //"ms-python.isort", + "ms-python.black-formatter", + "ms-python.isort", "ms-python.pylint", "ms-python.python", "ms-python.vscode-pylance", diff --git a/.vscode/extensions.json b/.vscode/extensions.json index 327bf5c51c..76dce33d5a 100644 --- a/.vscode/extensions.json +++ b/.vscode/extensions.json @@ -14,9 +14,8 @@ "lextudio.restructuredtext", "matangover.mypy", "ms-azuretools.vscode-docker", - // TODO: Enable additional formatter extensions: - //"ms-python.black-formatter", - //"ms-python.isort", + "ms-python.black-formatter", + "ms-python.isort", "ms-python.pylint", "ms-python.python", "ms-python.vscode-pylance", diff --git a/.vscode/settings.json b/.vscode/settings.json index 1e8eb58adb..6b9729290f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -125,14 +125,10 @@ ], "esbonio.sphinx.confDir": "${workspaceFolder}/doc/source", "esbonio.sphinx.buildDir": 
"${workspaceFolder}/doc/build/", - "autopep8.args": [ - "--experimental" - ], "[python]": { - // TODO: Enable black formatter - //"editor.defaultFormatter": "ms-python.black-formatter", - //"editor.formatOnSave": true, - //"editor.formatOnSaveMode": "modifications" + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + "editor.formatOnSaveMode": "modifications" }, // See Also .vscode/launch.json for environment variable args to pytest during debug sessions. // For the rest, see setup.cfg From ebd406eb4338dddf838967223957412f36b6bec0 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 3 Jul 2024 21:13:35 +0000 Subject: [PATCH 04/54] adjust line length for other checkers --- .editorconfig | 4 ++++ .pylintrc | 2 +- setup.cfg | 4 ++-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/.editorconfig b/.editorconfig index e984d47595..c2b6ed65db 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,6 +12,10 @@ charset = utf-8 # Note: this is not currently supported by all editors or their editorconfig plugins. max_line_length = 132 +# See Also: black configuration in pyproject.toml +[*.py] +max_line_length = 88 + # Makefiles need tab indentation [{Makefile,*.mk}] indent_style = tab diff --git a/.pylintrc b/.pylintrc index 6b308d1966..fdc93e2956 100644 --- a/.pylintrc +++ b/.pylintrc @@ -35,7 +35,7 @@ load-plugins= [FORMAT] # Maximum number of characters on a single line. -max-line-length=132 +max-line-length=88 [MESSAGE CONTROL] disable= diff --git a/setup.cfg b/setup.cfg index 492d2de7f2..661a7971cd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,7 +7,7 @@ count = True ignore = W503,W504 format = pylint # See Also: .editorconfig, .pylintrc -max-line-length = 132 +max-line-length = 88 show-source = True statistics = True @@ -25,7 +25,7 @@ match = .+(? 
Date: Wed, 3 Jul 2024 21:17:45 +0000 Subject: [PATCH 05/54] isort applied --- .../fio/scripts/local/process_fio_results.py | 1 - .../scripts/local/generate_redis_config.py | 2 +- .../scripts/local/generate_grub_config.py | 2 +- .../local/generate_kernel_config_script.py | 2 +- .../mlos_bench/config/schemas/__init__.py | 3 +- .../config/schemas/config_schemas.py | 5 +-- .../mlos_bench/environments/__init__.py | 9 ++-- .../environments/base_environment.py | 15 ++++++- .../mlos_bench/environments/composite_env.py | 6 +-- .../environments/local/local_env.py | 8 ++-- .../environments/local/local_fileshare_env.py | 9 ++-- .../mlos_bench/environments/mock_env.py | 6 +-- .../environments/remote/host_env.py | 3 +- .../environments/remote/network_env.py | 7 ++-- .../mlos_bench/environments/remote/os_env.py | 3 +- .../environments/remote/remote_env.py | 4 +- .../environments/remote/saas_env.py | 3 +- .../mlos_bench/environments/script_env.py | 1 - mlos_bench/mlos_bench/event_loop_context.py | 10 ++--- mlos_bench/mlos_bench/launcher.py | 20 +++------ mlos_bench/mlos_bench/optimizers/__init__.py | 2 +- .../mlos_bench/optimizers/base_optimizer.py | 9 ++-- .../optimizers/convert_configspace.py | 4 +- .../optimizers/grid_search_optimizer.py | 11 +++-- .../optimizers/mlos_core_optimizer.py | 21 +++++----- .../mlos_bench/optimizers/mock_optimizer.py | 8 ++-- .../optimizers/one_shot_optimizer.py | 2 +- .../optimizers/track_best_optimizer.py | 5 +-- mlos_bench/mlos_bench/os_environ.py | 2 +- .../mlos_bench/schedulers/base_scheduler.py | 5 +-- mlos_bench/mlos_bench/services/__init__.py | 3 +- .../mlos_bench/services/base_fileshare.py | 1 - .../mlos_bench/services/base_service.py | 2 +- .../mlos_bench/services/config_persistence.py | 41 +++++++++++++------ .../mlos_bench/services/local/__init__.py | 1 - .../mlos_bench/services/local/local_exec.py | 12 +++++- .../services/remote/azure/__init__.py | 1 - .../services/remote/azure/azure_auth.py | 3 +- .../remote/azure/azure_deployment_services.py | 3 +- .../services/remote/azure/azure_fileshare.py | 7 ++-- .../remote/azure/azure_network_services.py | 9 ++-- .../services/remote/azure/azure_saas.py | 1 - .../remote/azure/azure_vm_services.py | 9 ++-- .../services/remote/ssh/__init__.py | 2 +- .../services/remote/ssh/ssh_fileshare.py | 5 +-- .../services/remote/ssh/ssh_host_service.py | 7 ++-- .../services/remote/ssh/ssh_service.py | 31 ++++++++++---- .../mlos_bench/services/types/__init__.py | 5 ++- .../services/types/config_loader_type.py | 17 ++++++-- .../services/types/host_ops_type.py | 2 +- .../services/types/host_provisioner_type.py | 2 +- .../services/types/local_exec_type.py | 13 ++++-- .../types/network_provisioner_type.py | 2 +- .../mlos_bench/services/types/os_ops_type.py | 2 +- .../services/types/remote_config_type.py | 2 +- .../services/types/remote_exec_type.py | 2 +- .../services/types/vm_provisioner_type.py | 2 +- .../storage/base_experiment_data.py | 8 ++-- mlos_bench/mlos_bench/storage/base_storage.py | 3 +- .../mlos_bench/storage/base_trial_data.py | 8 ++-- .../base_tunable_config_trial_group_data.py | 2 +- mlos_bench/mlos_bench/storage/sql/common.py | 8 ++-- .../mlos_bench/storage/sql/experiment.py | 9 ++-- .../mlos_bench/storage/sql/experiment_data.py | 11 +++-- mlos_bench/mlos_bench/storage/sql/schema.py | 19 +++++++-- mlos_bench/mlos_bench/storage/sql/storage.py | 6 +-- mlos_bench/mlos_bench/storage/sql/trial.py | 6 +-- .../mlos_bench/storage/sql/trial_data.py | 12 ++++-- .../sql/tunable_config_trial_group_data.py | 6 ++- 
.../mlos_bench/storage/storage_factory.py | 2 +- mlos_bench/mlos_bench/tests/__init__.py | 12 +++--- .../mlos_bench/tests/config/__init__.py | 3 +- .../cli/test_load_cli_config_examples.py | 12 +++--- .../test_load_environment_config_examples.py | 4 +- .../test_load_global_config_examples.py | 4 +- .../test_load_optimizer_config_examples.py | 6 +-- .../tests/config/schemas/__init__.py | 3 +- .../config/schemas/cli/test_cli_schemas.py | 10 ++--- .../environments/test_environment_schemas.py | 13 +++--- .../schemas/globals/test_globals_schemas.py | 7 ++-- .../optimizers/test_optimizer_schemas.py | 17 ++++---- .../schedulers/test_scheduler_schemas.py | 13 +++--- .../schemas/services/test_services_schemas.py | 17 ++++---- .../schemas/storage/test_storage_schemas.py | 13 +++--- .../test_tunable_params_schemas.py | 7 ++-- .../test_tunable_values_schemas.py | 7 ++-- .../test_load_service_config_examples.py | 4 +- .../test_load_storage_config_examples.py | 4 +- mlos_bench/mlos_bench/tests/conftest.py | 12 +++--- .../mlos_bench/tests/environments/__init__.py | 1 - .../tests/environments/base_env_test.py | 2 +- .../composite_env_service_test.py | 2 +- .../tests/environments/composite_env_test.py | 2 +- .../local/composite_local_env_test.py | 6 +-- .../local/local_env_stdout_test.py | 2 +- .../local/local_env_telemetry_test.py | 7 ++-- .../environments/local/local_env_test.py | 2 +- .../environments/local/local_env_vars_test.py | 2 +- .../local/local_fileshare_env_test.py | 7 ++-- .../tests/environments/remote/test_ssh_env.py | 9 ++-- .../tests/event_loop_context_test.py | 3 +- .../tests/launcher_parse_args_test.py | 6 +-- .../mlos_bench/tests/launcher_run_test.py | 2 +- .../mlos_bench/tests/optimizers/conftest.py | 5 +-- .../optimizers/grid_search_optimizer_test.py | 4 +- .../tests/optimizers/llamatune_opt_test.py | 3 +- .../tests/optimizers/mlos_core_opt_df_test.py | 3 +- .../optimizers/mlos_core_opt_smac_test.py | 8 ++-- .../optimizers/opt_bulk_register_test.py | 2 +- .../optimizers/toy_optimization_loop_test.py | 15 +++---- .../tests/services/config_persistence_test.py | 2 +- .../services/local/local_exec_python_test.py | 7 ++-- .../tests/services/local/local_exec_test.py | 4 +- .../local/mock/mock_local_exec_service.py | 11 ++++- .../tests/services/remote/azure/__init__.py | 2 +- .../remote/azure/azure_fileshare_test.py | 2 +- .../azure/azure_network_services_test.py | 2 - .../remote/azure/azure_vm_services_test.py | 2 - .../tests/services/remote/azure/conftest.py | 2 +- .../remote/mock/mock_fileshare_service.py | 2 +- .../remote/mock/mock_network_service.py | 4 +- .../services/remote/mock/mock_vm_service.py | 2 +- .../tests/services/remote/ssh/__init__.py | 1 - .../tests/services/remote/ssh/fixtures.py | 20 ++++----- .../services/remote/ssh/test_ssh_fileshare.py | 10 ++--- .../remote/ssh/test_ssh_host_service.py | 16 ++++---- .../services/remote/ssh/test_ssh_service.py | 22 ++++++---- .../test_service_method_registering.py | 7 +++- .../mlos_bench/tests/storage/exp_data_test.py | 5 +-- .../mlos_bench/tests/storage/exp_load_test.py | 5 +-- .../mlos_bench/tests/storage/sql/fixtures.py | 11 +++-- .../tests/storage/trial_schedule_test.py | 1 - .../tests/storage/trial_telemetry_test.py | 8 ++-- .../tunable_config_trial_group_data_test.py | 3 +- .../mlos_bench/tests/test_with_alt_tz.py | 3 +- .../tests/tunable_groups_fixtures.py | 3 +- .../tunables/test_tunables_size_props.py | 1 - .../tunable_to_configspace_distr_test.py | 8 ++-- .../tunables/tunable_to_configspace_test.py | 5 +-- 
.../mlos_bench/tunables/covariant_group.py | 1 - mlos_bench/mlos_bench/tunables/tunable.py | 17 ++++++-- .../mlos_bench/tunables/tunable_groups.py | 3 +- mlos_bench/mlos_bench/util.py | 22 ++++++---- mlos_bench/setup.py | 7 ++-- mlos_core/mlos_core/optimizers/__init__.py | 6 +-- .../bayesian_optimizers/__init__.py | 5 ++- .../bayesian_optimizers/bayesian_optimizer.py | 3 +- .../bayesian_optimizers/smac_optimizer.py | 20 ++++++--- .../mlos_core/optimizers/flaml_optimizer.py | 7 +++- mlos_core/mlos_core/optimizers/optimizer.py | 2 +- .../mlos_core/spaces/adapters/llamatune.py | 9 ++-- .../mlos_core/spaces/converters/flaml.py | 6 +-- mlos_core/mlos_core/tests/__init__.py | 1 - .../optimizers/bayesian_optimizers_test.py | 5 +-- .../mlos_core/tests/optimizers/conftest.py | 3 +- .../tests/optimizers/one_hot_test.py | 7 ++-- .../optimizers/optimizer_multiobj_test.py | 10 ++--- .../tests/optimizers/optimizer_test.py | 25 ++++++----- .../tests/spaces/adapters/llamatune_test.py | 3 +- .../adapters/space_adapter_factory_test.py | 10 +++-- .../mlos_core/tests/spaces/spaces_test.py | 14 +++---- mlos_core/mlos_core/util.py | 2 +- mlos_core/setup.py | 5 +-- mlos_viz/mlos_viz/__init__.py | 4 +- mlos_viz/mlos_viz/base.py | 9 ++-- mlos_viz/mlos_viz/dabl.py | 4 +- mlos_viz/mlos_viz/tests/__init__.py | 1 - mlos_viz/mlos_viz/tests/test_base_plot.py | 11 ++--- mlos_viz/mlos_viz/tests/test_dabl_plot.py | 5 +-- mlos_viz/mlos_viz/tests/test_mlos_viz.py | 5 +-- mlos_viz/setup.py | 8 ++-- 171 files changed, 596 insertions(+), 551 deletions(-) diff --git a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py index 679d0d4ceb..c32dea9bf6 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py @@ -10,7 +10,6 @@ import argparse import itertools import json - from typing import Any, Iterator, Tuple import pandas diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py index c1850f5e03..949b9f9d91 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py @@ -9,8 +9,8 @@ Run: `./generate_redis_config.py ./input-params.json ./output-redis.cfg` """ -import json import argparse +import json def _main(fname_input: str, fname_output: str) -> None: diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py index d03e4f5771..de344d61fb 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py @@ -9,8 +9,8 @@ Run: `./generate_grub_config.py ./input-boot-params.json ./output-grub.cfg` """ -import json import argparse +import json def _main(fname_input: str, fname_output: str) -> None: diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py 
b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py index e6d8039729..85a49a1817 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py @@ -9,8 +9,8 @@ Run: `./generate_kernel_config_script.py ./kernel-params.json ./kernel-params-meta.json ./config-kernel.sh` """ -import json import argparse +import json def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py index 73daf81c3b..fa3b63e2e6 100644 --- a/mlos_bench/mlos_bench/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/config/schemas/__init__.py @@ -6,8 +6,7 @@ A module for managing config schemas and their validation. """ -from mlos_bench.config.schemas.config_schemas import ConfigSchema, CONFIG_SCHEMA_DIR - +from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema __all__ = [ 'ConfigSchema', diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index 9c4a066be5..82cbcacce2 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -6,14 +6,13 @@ A simple class for describing where to find different config schemas and validating configs against them. """ +import json # schema files are pure json - no comments import logging from enum import Enum -from os import path, walk, environ +from os import environ, path, walk from typing import Dict, Iterator, Mapping -import json # schema files are pure json - no comments import jsonschema - from referencing import Registry, Resource from referencing.jsonschema import DRAFT202012 diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py index 9ed5480908..a1ccadae5f 100644 --- a/mlos_bench/mlos_bench/environments/__init__.py +++ b/mlos_bench/mlos_bench/environments/__init__.py @@ -6,14 +6,13 @@ Tunable Environments for mlos_bench. 
""" -from mlos_bench.environments.status import Status from mlos_bench.environments.base_environment import Environment - -from mlos_bench.environments.mock_env import MockEnv -from mlos_bench.environments.remote.remote_env import RemoteEnv +from mlos_bench.environments.composite_env import CompositeEnv from mlos_bench.environments.local.local_env import LocalEnv from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv -from mlos_bench.environments.composite_env import CompositeEnv +from mlos_bench.environments.mock_env import MockEnv +from mlos_bench.environments.remote.remote_env import RemoteEnv +from mlos_bench.environments.status import Status __all__ = [ 'Status', diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index 508d78589b..61fbd69f50 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -11,10 +11,21 @@ import logging from datetime import datetime from types import TracebackType -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type, TYPE_CHECKING, Union -from typing_extensions import Literal +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + Type, + Union, +) from pytz import UTC +from typing_extensions import Literal from mlos_bench.config.schemas import ConfigSchema from mlos_bench.dict_templater import DictTemplater diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index 06b4f431be..a71b8ab9be 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -8,14 +8,14 @@ import logging from datetime import datetime - from types import TracebackType from typing import Any, Dict, List, Optional, Tuple, Type + from typing_extensions import Literal -from mlos_bench.services.base_service import Service -from mlos_bench.environments.status import Status from mlos_bench.environments.base_environment import Environment +from mlos_bench.environments.status import Status +from mlos_bench.services.base_service import Service from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index 01f0337c1f..da20f5c961 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -9,20 +9,18 @@ import json import logging import sys - +from contextlib import nullcontext from datetime import datetime from tempfile import TemporaryDirectory -from contextlib import nullcontext - from types import TracebackType from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union -from typing_extensions import Literal import pandas +from typing_extensions import Literal -from mlos_bench.environments.status import Status from mlos_bench.environments.base_environment import Environment from mlos_bench.environments.script_env import ScriptEnv +from mlos_bench.environments.status import Status from mlos_bench.services.base_service import Service from mlos_bench.services.types.local_exec_type import SupportsLocalExec from mlos_bench.tunables.tunable import TunableValue diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py 
b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index 6aea7acfc4..174afd387c 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -8,16 +8,15 @@ """ import logging - from datetime import datetime from string import Template -from typing import Any, Dict, List, Generator, Iterable, Mapping, Optional, Tuple +from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple +from mlos_bench.environments.local.local_env import LocalEnv +from mlos_bench.environments.status import Status from mlos_bench.services.base_service import Service -from mlos_bench.services.types.local_exec_type import SupportsLocalExec from mlos_bench.services.types.fileshare_type import SupportsFileShareOps -from mlos_bench.environments.status import Status -from mlos_bench.environments.local.local_env import LocalEnv +from mlos_bench.services.types.local_exec_type import SupportsLocalExec from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index d8ffe3e47d..cc47b95500 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -6,16 +6,16 @@ Scheduler-side environment to mock the benchmark results. """ -import random import logging +import random from datetime import datetime from typing import Dict, Optional, Tuple import numpy -from mlos_bench.services.base_service import Service -from mlos_bench.environments.status import Status from mlos_bench.environments.base_environment import Environment +from mlos_bench.environments.status import Status +from mlos_bench.services.base_service import Service from mlos_bench.tunables import Tunable, TunableGroups, TunableValue _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py index 4b63e47278..05896c9e60 100644 --- a/mlos_bench/mlos_bench/environments/remote/host_env.py +++ b/mlos_bench/mlos_bench/environments/remote/host_env.py @@ -6,9 +6,8 @@ Remote host Environment. """ -from typing import Optional - import logging +from typing import Optional from mlos_bench.environments.base_environment import Environment from mlos_bench.services.base_service import Service diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index d049beacfd..552f1729d9 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -6,13 +6,14 @@ Network Environment. 
""" -from typing import Optional - import logging +from typing import Optional from mlos_bench.environments.base_environment import Environment from mlos_bench.services.base_service import Service -from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning +from mlos_bench.services.types.network_provisioner_type import ( + SupportsNetworkProvisioning, +) from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py index bb5a0238a4..ef733c77c2 100644 --- a/mlos_bench/mlos_bench/environments/remote/os_env.py +++ b/mlos_bench/mlos_bench/environments/remote/os_env.py @@ -6,9 +6,8 @@ OS-level remote Environment on Azure. """ -from typing import Optional - import logging +from typing import Optional from mlos_bench.environments.base_environment import Environment from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index 0320b02769..cf38a57b01 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -14,11 +14,11 @@ from pytz import UTC -from mlos_bench.environments.status import Status from mlos_bench.environments.script_env import ScriptEnv +from mlos_bench.environments.status import Status from mlos_bench.services.base_service import Service -from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec from mlos_bench.services.types.host_ops_type import SupportsHostOps +from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index 96a91db292..b661bfad7e 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -6,9 +6,8 @@ Cloud-based (configurable) SaaS environment. """ -from typing import Optional - import logging +from typing import Optional from mlos_bench.environments.base_environment import Environment from mlos_bench.services.base_service import Service diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index 05c5fdec86..129ac21a0f 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -15,7 +15,6 @@ from mlos_bench.services.base_service import Service from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.util import try_parse_val _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index e618c18b16..4555ab7f50 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -6,14 +6,14 @@ EventLoopContext class definition. 
""" -from asyncio import AbstractEventLoop -from concurrent.futures import Future -from typing import Any, Coroutine, Optional, TypeVar -from threading import Lock as ThreadLock, Thread - import asyncio import logging import sys +from asyncio import AbstractEventLoop +from concurrent.futures import Future +from threading import Lock as ThreadLock +from threading import Thread +from typing import Any, Coroutine, Optional, TypeVar if sys.version_info >= (3, 10): from typing import TypeAlias diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index a9aa9e3f46..c8e48dab69 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -13,31 +13,23 @@ import argparse import logging import sys - from typing import Any, Dict, Iterable, List, Optional, Tuple from mlos_bench.config.schemas import ConfigSchema from mlos_bench.dict_templater import DictTemplater -from mlos_bench.util import try_parse_val - -from mlos_bench.tunables.tunable import TunableValue -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.environments.base_environment import Environment - from mlos_bench.optimizers.base_optimizer import Optimizer from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer - -from mlos_bench.storage.base_storage import Storage - +from mlos_bench.schedulers.base_scheduler import Scheduler from mlos_bench.services.base_service import Service -from mlos_bench.services.local.local_exec import LocalExecService from mlos_bench.services.config_persistence import ConfigPersistenceService - -from mlos_bench.schedulers.base_scheduler import Scheduler - +from mlos_bench.services.local.local_exec import LocalExecService from mlos_bench.services.types.config_loader_type import SupportsConfigLoading - +from mlos_bench.storage.base_storage import Storage +from mlos_bench.tunables.tunable import TunableValue +from mlos_bench.tunables.tunable_groups import TunableGroups +from mlos_bench.util import try_parse_val _LOG_LEVEL = logging.INFO _LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s' diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py index f875917251..f10fa3c82e 100644 --- a/mlos_bench/mlos_bench/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/optimizers/__init__.py @@ -7,9 +7,9 @@ """ from mlos_bench.optimizers.base_optimizer import Optimizer +from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer -from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer __all__ = [ 'Optimizer', diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index 911c624315..b9df1db1b7 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -9,20 +9,19 @@ import logging from abc import ABCMeta, abstractmethod -from distutils.util import strtobool # pylint: disable=deprecated-module - +from distutils.util import strtobool # pylint: disable=deprecated-module from types import TracebackType from typing import Dict, Optional, Sequence, Tuple, Type, Union -from typing_extensions import Literal from ConfigSpace import ConfigurationSpace +from typing_extensions import Literal from 
mlos_bench.config.schemas import ConfigSchema -from mlos_bench.services.base_service import Service from mlos_bench.environments.status import Status +from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace +from mlos_bench.services.base_service import Service from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.optimizers.convert_configspace import tunable_groups_to_configspace _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index 6978a8d410..62341c613d 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -7,7 +7,6 @@ """ import logging - from typing import Dict, List, Optional, Tuple, Union from ConfigSpace import ( @@ -21,9 +20,10 @@ Normal, Uniform, ) + from mlos_bench.tunables.tunable import Tunable, TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.util import try_parse_val, nullable +from mlos_bench.util import nullable, try_parse_val _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 0e836212d7..4f207f5fc9 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -7,19 +7,18 @@ """ import logging +from typing import Dict, Iterable, Optional, Sequence, Set, Tuple -from typing import Dict, Iterable, Set, Optional, Sequence, Tuple - -import numpy as np import ConfigSpace +import numpy as np from ConfigSpace.util import generate_grid from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable import TunableValue -from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer from mlos_bench.optimizers.convert_configspace import configspace_data_to_tunable_values +from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer from mlos_bench.services.base_service import Service +from mlos_bench.tunables.tunable import TunableValue +from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index e0235f76b9..d7d50f1ca5 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -8,28 +8,29 @@ import logging import os - from types import TracebackType from typing import Dict, Optional, Sequence, Tuple, Type, Union -from typing_extensions import Literal import pandas as pd - -from mlos_core.optimizers import ( - BaseOptimizer, OptimizerType, OptimizerFactory, SpaceAdapterType, DEFAULT_OPTIMIZER_TYPE -) +from typing_extensions import Literal from mlos_bench.environments.status import Status -from mlos_bench.services.base_service import Service -from mlos_bench.tunables.tunable import TunableValue -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.optimizers.base_optimizer import Optimizer - from mlos_bench.optimizers.convert_configspace import ( TunableValueKind, configspace_data_to_tunable_values, special_param_names, ) +from mlos_bench.services.base_service import Service +from 
mlos_bench.tunables.tunable import TunableValue +from mlos_bench.tunables.tunable_groups import TunableGroups +from mlos_core.optimizers import ( + DEFAULT_OPTIMIZER_TYPE, + BaseOptimizer, + OptimizerFactory, + OptimizerType, + SpaceAdapterType, +) _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index 7d2caff8ff..ada4411b58 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -6,17 +6,15 @@ Mock optimizer for mlos_bench. """ -import random import logging - +import random from typing import Callable, Dict, Optional, Sequence from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable import Tunable, TunableValue -from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.optimizers.track_best_optimizer import TrackBestOptimizer from mlos_bench.services.base_service import Service +from mlos_bench.tunables.tunable import Tunable, TunableValue +from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index 088ed03bdf..9ad1070c46 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -9,9 +9,9 @@ import logging from typing import Optional +from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.services.base_service import Service from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.optimizers.mock_optimizer import MockOptimizer _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py index c5d07ab93d..32a23142e3 100644 --- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py @@ -11,11 +11,10 @@ from typing import Dict, Optional, Tuple, Union from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable import TunableValue -from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.optimizers.base_optimizer import Optimizer from mlos_bench.services.base_service import Service +from mlos_bench.tunables.tunable import TunableValue +from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index f3c556c06a..a7912688a1 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py @@ -29,7 +29,7 @@ # Handle case sensitivity differences between platforms. 
# https://stackoverflow.com/a/19023293 if sys.platform == 'win32': - import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8) + import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8) environ: EnvironType = nt.environ else: environ: EnvironType = os.environ diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index 160cb2224b..0b6733e423 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -8,14 +8,13 @@ import json import logging -from datetime import datetime - from abc import ABCMeta, abstractmethod +from datetime import datetime from types import TracebackType from typing import Any, Dict, Optional, Tuple, Type -from typing_extensions import Literal from pytz import UTC +from typing_extensions import Literal from mlos_bench.environments.base_environment import Environment from mlos_bench.optimizers.base_optimizer import Optimizer diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py index 89e71be815..bcc7d02d6f 100644 --- a/mlos_bench/mlos_bench/services/__init__.py +++ b/mlos_bench/mlos_bench/services/__init__.py @@ -6,11 +6,10 @@ Services for implementing Environments for mlos_bench. """ -from mlos_bench.services.base_service import Service from mlos_bench.services.base_fileshare import FileShareService +from mlos_bench.services.base_service import Service from mlos_bench.services.local.local_exec import LocalExecService - __all__ = [ 'Service', 'FileShareService', diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index b7282089e4..f00a7a1a00 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -7,7 +7,6 @@ """ import logging - from abc import ABCMeta, abstractmethod from typing import Any, Callable, Dict, List, Optional, Union diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index b171568172..e7c9365bf7 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -8,9 +8,9 @@ import json import logging - from types import TracebackType from typing import Any, Callable, Dict, List, Optional, Set, Type, Union + from typing_extensions import Literal from mlos_bench.config.schemas import ConfigSchema diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index 4329d8f7e3..cac3216d61 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -8,16 +8,24 @@ service functions. 
""" +import json # For logging only +import logging import os import sys - -import json # For logging only -import logging - -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, TYPE_CHECKING - -import json5 # To read configs with comments and other JSON5 syntax features -from jsonschema import ValidationError, SchemaError +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) + +import json5 # To read configs with comments and other JSON5 syntax features +from jsonschema import SchemaError, ValidationError from mlos_bench.config.schemas import ConfigSchema from mlos_bench.environments.base_environment import Environment @@ -26,7 +34,12 @@ from mlos_bench.services.types.config_loader_type import SupportsConfigLoading from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.util import instantiate_from_config, merge_parameters, path_join, preprocess_dynamic_configs +from mlos_bench.util import ( + instantiate_from_config, + merge_parameters, + path_join, + preprocess_dynamic_configs, +) if sys.version_info < (3, 10): from importlib_resources import files @@ -34,8 +47,8 @@ from importlib.resources import files if TYPE_CHECKING: - from mlos_bench.storage.base_storage import Storage from mlos_bench.schedulers.base_scheduler import Scheduler + from mlos_bench.storage.base_storage import Storage _LOG = logging.getLogger(__name__) @@ -296,7 +309,9 @@ def build_storage(self, *, A new instance of the Storage class. """ (class_name, class_config) = self.prepare_class_load(config, global_config) - from mlos_bench.storage.base_storage import Storage # pylint: disable=import-outside-toplevel + from mlos_bench.storage.base_storage import ( + Storage, # pylint: disable=import-outside-toplevel + ) inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract] config=class_config, global_config=global_config, @@ -335,7 +350,9 @@ def build_scheduler(self, *, A new instance of the Scheduler. 
""" (class_name, class_config) = self.prepare_class_load(config, global_config) - from mlos_bench.schedulers.base_scheduler import Scheduler # pylint: disable=import-outside-toplevel + from mlos_bench.schedulers.base_scheduler import ( + Scheduler, # pylint: disable=import-outside-toplevel + ) inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract] config=class_config, global_config=global_config, diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py index f35ea4c7e8..abb87c8b52 100644 --- a/mlos_bench/mlos_bench/services/local/__init__.py +++ b/mlos_bench/mlos_bench/services/local/__init__.py @@ -8,7 +8,6 @@ from mlos_bench.services.local.local_exec import LocalExecService - __all__ = [ 'LocalExecService', ] diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 2ca567dfd4..47534be7b1 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -12,10 +12,18 @@ import shlex import subprocess import sys - from string import Template from typing import ( - Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, ) from mlos_bench.os_environ import environ diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py index 741593d035..61a6c74942 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py @@ -12,7 +12,6 @@ from mlos_bench.services.remote.azure.azure_saas import AzureSaaSConfigService from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService - __all__ = [ 'AzureAuthService', 'AzureFileShareService', diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index b1e484c009..4121446caf 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -11,10 +11,9 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union -from pytz import UTC - import azure.identity as azure_id from azure.keyvault.secrets import SecretClient +from pytz import UTC from mlos_bench.services.base_service import Service from mlos_bench.services.types.authenticator_type import SupportsAuth diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 187b7c055b..9f2b504aff 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -8,9 +8,8 @@ import abc import json -import time import logging - +import time from typing import Any, Callable, Dict, List, Optional, Tuple, Union import requests diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index af09f4c723..6ccd4ba09d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -6,16 +6,15 @@ A collection FileShare functions for interacting with Azure File Shares. 
""" -import os import logging - +import os from typing import Any, Callable, Dict, List, Optional, Set, Union -from azure.storage.fileshare import ShareClient from azure.core.exceptions import ResourceNotFoundError +from azure.storage.fileshare import ShareClient -from mlos_bench.services.base_service import Service from mlos_bench.services.base_fileshare import FileShareService +from mlos_bench.services.base_service import Service from mlos_bench.util import check_required_params _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index 081d5d842e..d65ee02cfd 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -7,13 +7,16 @@ """ import logging - from typing import Any, Callable, Dict, List, Optional, Tuple, Union from mlos_bench.environments.status import Status from mlos_bench.services.base_service import Service -from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService -from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning +from mlos_bench.services.remote.azure.azure_deployment_services import ( + AzureDeploymentService, +) +from mlos_bench.services.types.network_provisioner_type import ( + SupportsNetworkProvisioning, +) from mlos_bench.util import merge_parameters _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index 34bec7d25e..a92d279a6d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -6,7 +6,6 @@ A collection Service functions for configuring SaaS instances on Azure. 
""" import logging - from typing import Any, Callable, Dict, List, Optional, Tuple, Union import requests diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index 7fdbdc18df..ddce3cc935 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -8,18 +8,19 @@ import json import logging - from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import requests from mlos_bench.environments.status import Status from mlos_bench.services.base_service import Service -from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService -from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec -from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning +from mlos_bench.services.remote.azure.azure_deployment_services import ( + AzureDeploymentService, +) from mlos_bench.services.types.host_ops_type import SupportsHostOps +from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning from mlos_bench.services.types.os_ops_type import SupportsOSOps +from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec from mlos_bench.util import merge_parameters _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/services/remote/ssh/__init__.py index 2ab1705a74..cd897649ec 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/__init__.py @@ -4,8 +4,8 @@ # """SSH remote service.""" -from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService +from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService __all__ = [ "SshHostService", diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index f753947aa7..f623cdfcc8 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -6,12 +6,11 @@ A collection functions for interacting with SSH servers as file shares. """ +import logging from enum import Enum from typing import Tuple, Union -import logging - -from asyncssh import scp, SFTPError, SFTPNoSuchFile, SFTPFailure, SSHClientConnection +from asyncssh import SFTPError, SFTPFailure, SFTPNoSuchFile, SSHClientConnection, scp from mlos_bench.services.base_fileshare import FileShareService from mlos_bench.services.remote.ssh.ssh_service import SshService diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index 40c84e6300..a650ff0707 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -6,18 +6,17 @@ A collection Service functions for managing hosts via SSH. 
""" +import logging from concurrent.futures import Future from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import logging - -from asyncssh import SSHCompletedProcess, ConnectionLost, DisconnectError, ProcessError +from asyncssh import ConnectionLost, DisconnectError, ProcessError, SSHCompletedProcess from mlos_bench.environments.status import Status from mlos_bench.services.base_service import Service from mlos_bench.services.remote.ssh.ssh_service import SshService -from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec from mlos_bench.services.types.os_ops_type import SupportsOSOps +from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec from mlos_bench.util import merge_parameters _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 50ab07d4d2..8bc90eb3da 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -6,21 +6,36 @@ A collection functions for interacting with SSH servers as file shares. """ -from abc import ABCMeta -from asyncio import Event as CoroEvent, Lock as CoroLock -from warnings import warn -from types import TracebackType -from typing import Any, Callable, Coroutine, Dict, List, Literal, Optional, Tuple, Type, Union -from threading import current_thread - import logging import os +from abc import ABCMeta +from asyncio import Event as CoroEvent +from asyncio import Lock as CoroLock +from threading import current_thread +from types import TracebackType +from typing import ( + Any, + Callable, + Coroutine, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + Union, +) +from warnings import warn import asyncssh from asyncssh.connection import SSHClientConnection +from mlos_bench.event_loop_context import ( + CoroReturnType, + EventLoopContext, + FutureReturnType, +) from mlos_bench.services.base_service import Service -from mlos_bench.event_loop_context import EventLoopContext, CoroReturnType, FutureReturnType from mlos_bench.util import nullable _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index 2a9cbe3248..725d0c3306 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -11,11 +11,12 @@ from mlos_bench.services.types.fileshare_type import SupportsFileShareOps from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning from mlos_bench.services.types.local_exec_type import SupportsLocalExec -from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning +from mlos_bench.services.types.network_provisioner_type import ( + SupportsNetworkProvisioning, +) from mlos_bench.services.types.remote_config_type import SupportsRemoteConfig from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec - __all__ = [ 'SupportsAuth', 'SupportsConfigLoading', diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index 401e4c6720..05853da0a9 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -6,17 +6,26 @@ Protocol interface for helper functions to lookup and load configs. 
""" -from typing import Any, Dict, List, Iterable, Optional, Union, Protocol, runtime_checkable, TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Protocol, + Union, + runtime_checkable, +) from mlos_bench.config.schemas import ConfigSchema from mlos_bench.tunables.tunable import TunableValue - # Avoid's circular import issues. if TYPE_CHECKING: - from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.services.base_service import Service from mlos_bench.environments.base_environment import Environment + from mlos_bench.services.base_service import Service + from mlos_bench.tunables.tunable_groups import TunableGroups @runtime_checkable diff --git a/mlos_bench/mlos_bench/services/types/host_ops_type.py b/mlos_bench/mlos_bench/services/types/host_ops_type.py index a5d0b5b036..5418f8b1d3 100644 --- a/mlos_bench/mlos_bench/services/types/host_ops_type.py +++ b/mlos_bench/mlos_bench/services/types/host_ops_type.py @@ -6,7 +6,7 @@ Protocol interface for Host/VM boot operations. """ -from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable if TYPE_CHECKING: from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index b3560783fc..77b481e48e 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -6,7 +6,7 @@ Protocol interface for Host/VM provisioning operations. """ -from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable if TYPE_CHECKING: from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index 1a3808bb61..c4c5f01ddc 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -7,10 +7,17 @@ scripts and commands locally on the scheduler side. """ -from typing import Iterable, Mapping, Optional, Tuple, Union, Protocol, runtime_checkable - -import tempfile import contextlib +import tempfile +from typing import ( + Iterable, + Mapping, + Optional, + Protocol, + Tuple, + Union, + runtime_checkable, +) from mlos_bench.tunables.tunable import TunableValue diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 5b6a9a6936..fb753aa21c 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -6,7 +6,7 @@ Protocol interface for Network provisioning operations. """ -from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable if TYPE_CHECKING: from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/services/types/os_ops_type.py b/mlos_bench/mlos_bench/services/types/os_ops_type.py index ba36c6914a..6d5cea34e5 100644 --- a/mlos_bench/mlos_bench/services/types/os_ops_type.py +++ b/mlos_bench/mlos_bench/services/types/os_ops_type.py @@ -6,7 +6,7 @@ Protocol interface for Host/OS operations. 
""" -from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable if TYPE_CHECKING: from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index 4008aff576..c653e10c2b 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -6,7 +6,7 @@ Protocol interface for configuring cloud services. """ -from typing import Any, Dict, Protocol, Tuple, TYPE_CHECKING, runtime_checkable +from typing import TYPE_CHECKING, Any, Dict, Protocol, Tuple, runtime_checkable if TYPE_CHECKING: from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index 8dd41e51a8..096cb3c675 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -7,7 +7,7 @@ scripts on a remote host OS. """ -from typing import Iterable, Tuple, Protocol, runtime_checkable, TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable, Protocol, Tuple, runtime_checkable if TYPE_CHECKING: from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py index 0574d25c61..19747b3f12 100644 --- a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py @@ -6,7 +6,7 @@ Protocol interface for VM provisioning operations. """ -from typing import Tuple, Protocol, runtime_checkable, TYPE_CHECKING +from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable if TYPE_CHECKING: from mlos_bench.environments.status import Status diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index b112a7b575..ce07e44e2b 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -7,8 +7,8 @@ """ from abc import ABCMeta, abstractmethod -from distutils.util import strtobool # pylint: disable=deprecated-module -from typing import Dict, Literal, Optional, Tuple, TYPE_CHECKING +from distutils.util import strtobool # pylint: disable=deprecated-module +from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple import pandas @@ -16,7 +16,9 @@ if TYPE_CHECKING: from mlos_bench.storage.base_trial_data import TrialData - from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData + from mlos_bench.storage.base_tunable_config_trial_group_data import ( + TunableConfigTrialGroupData, + ) class ExperimentData(metaclass=ABCMeta): diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 350b1b8ec8..2165fa706f 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -10,7 +10,8 @@ from abc import ABCMeta, abstractmethod from datetime import datetime from types import TracebackType -from typing import Optional, List, Tuple, Dict, Iterator, Type, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple, Type + from typing_extensions import Literal from mlos_bench.config.schemas import ConfigSchema diff --git 
a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index 93d1b62c9b..b3b2bed86a 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -7,18 +7,20 @@ """ from abc import ABCMeta, abstractmethod from datetime import datetime -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional import pandas from pytz import UTC from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable import TunableValue from mlos_bench.storage.base_tunable_config_data import TunableConfigData from mlos_bench.storage.util import kv_df_to_dict +from mlos_bench.tunables.tunable import TunableValue if TYPE_CHECKING: - from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData + from mlos_bench.storage.base_tunable_config_trial_group_data import ( + TunableConfigTrialGroupData, + ) class TrialData(metaclass=ABCMeta): diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py index b64524fb85..18c50035a9 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py @@ -7,7 +7,7 @@ """ from abc import ABCMeta, abstractmethod -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional import pandas diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index a3895f065a..c7ee73a3bc 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -8,13 +8,13 @@ from typing import Dict, Optional import pandas -from sqlalchemy import Engine, Integer, func, and_, select +from sqlalchemy import Engine, Integer, and_, func, select from mlos_bench.environments.status import Status from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.base_trial_data import TrialData from mlos_bench.storage.sql.schema import DbSchema -from mlos_bench.util import utcify_timestamp, utcify_nullable_timestamp +from mlos_bench.util import utcify_nullable_timestamp, utcify_timestamp def get_trials( @@ -27,7 +27,9 @@ def get_trials( restricted by tunable_config_id. Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData. """ - from mlos_bench.storage.sql.trial_data import TrialSqlData # pylint: disable=import-outside-toplevel,cyclic-import + from mlos_bench.storage.sql.trial_data import ( + TrialSqlData, # pylint: disable=import-outside-toplevel,cyclic-import + ) with engine.connect() as conn: # Build up sql a statement for fetching trials. stmt = schema.trial.select().where( diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 721971bfeb..58ee3dddb5 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -6,20 +6,19 @@ Saving and restoring the benchmark data using SQLAlchemy. 
""" -import logging import hashlib +import logging from datetime import datetime -from typing import Optional, Tuple, List, Literal, Dict, Iterator, Any +from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple from pytz import UTC - -from sqlalchemy import Engine, Connection, CursorResult, Table, column, func, select +from sqlalchemy import Connection, CursorResult, Engine, Table, column, func, select from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.storage.base_storage import Storage from mlos_bench.storage.sql.schema import DbSchema from mlos_bench.storage.sql.trial import Trial +from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import nullable, utcify_timestamp _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index 31b1d64af0..eaa6e1041f 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -5,9 +5,8 @@ """ An interface to access the experiment benchmark data stored in SQL DB. """ -from typing import Dict, Literal, Optional - import logging +from typing import Dict, Literal, Optional import pandas from sqlalchemy import Engine, Integer, String, func @@ -15,11 +14,15 @@ from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.base_trial_data import TrialData from mlos_bench.storage.base_tunable_config_data import TunableConfigData -from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData +from mlos_bench.storage.base_tunable_config_trial_group_data import ( + TunableConfigTrialGroupData, +) from mlos_bench.storage.sql import common from mlos_bench.storage.sql.schema import DbSchema from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData -from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData +from mlos_bench.storage.sql.tunable_config_trial_group_data import ( + TunableConfigTrialGroupSqlData, +) _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index c59adc1c67..9a1eca2744 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -7,12 +7,23 @@ """ import logging -from typing import List, Any +from typing import Any, List from sqlalchemy import ( - Engine, MetaData, Dialect, create_mock_engine, - Table, Column, Sequence, Integer, Float, String, DateTime, - PrimaryKeyConstraint, ForeignKeyConstraint, UniqueConstraint, + Column, + DateTime, + Dialect, + Engine, + Float, + ForeignKeyConstraint, + Integer, + MetaData, + PrimaryKeyConstraint, + Sequence, + String, + Table, + UniqueConstraint, + create_mock_engine, ) _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index 1bfe695300..bde38575bd 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -11,13 +11,13 @@ from sqlalchemy import URL, create_engine -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.services.base_service import Service +from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.base_storage import Storage -from mlos_bench.storage.sql.schema 
import DbSchema from mlos_bench.storage.sql.experiment import Experiment -from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.sql.experiment_data import ExperimentSqlData +from mlos_bench.storage.sql.schema import DbSchema +from mlos_bench.tunables.tunable_groups import TunableGroups _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 4806056e05..7ac7958845 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -8,15 +8,15 @@ import logging from datetime import datetime -from typing import List, Literal, Optional, Tuple, Dict, Any +from typing import Any, Dict, List, Literal, Optional, Tuple -from sqlalchemy import Engine, Connection +from sqlalchemy import Connection, Engine from sqlalchemy.exc import IntegrityError from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.storage.base_storage import Storage from mlos_bench.storage.sql.schema import DbSchema +from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import nullable, utcify_timestamp _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index 7353e96e79..5a6f8a5ee8 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -6,20 +6,22 @@ An interface to access the benchmark trial data stored in SQL DB. """ from datetime import datetime -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import pandas from sqlalchemy import Engine +from mlos_bench.environments.status import Status from mlos_bench.storage.base_trial_data import TrialData from mlos_bench.storage.base_tunable_config_data import TunableConfigData -from mlos_bench.environments.status import Status from mlos_bench.storage.sql.schema import DbSchema from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData from mlos_bench.util import utcify_timestamp if TYPE_CHECKING: - from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData + from mlos_bench.storage.base_tunable_config_trial_group_data import ( + TunableConfigTrialGroupData, + ) class TrialSqlData(TrialData): @@ -63,7 +65,9 @@ def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": Retrieve the trial's tunable config group configuration data from the storage. """ # pylint: disable=import-outside-toplevel - from mlos_bench.storage.sql.tunable_config_trial_group_data import TunableConfigTrialGroupSqlData + from mlos_bench.storage.sql.tunable_config_trial_group_data import ( + TunableConfigTrialGroupSqlData, + ) return TunableConfigTrialGroupSqlData(engine=self._engine, schema=self._schema, experiment_id=self._experiment_id, tunable_config_id=self._tunable_config_id) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index 775683133d..eb389a5940 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -6,13 +6,15 @@ An interface to access the tunable config trial group data stored in SQL DB. 
""" -from typing import Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Dict, Optional import pandas from sqlalchemy import Engine, Integer, func from mlos_bench.storage.base_tunable_config_data import TunableConfigData -from mlos_bench.storage.base_tunable_config_trial_group_data import TunableConfigTrialGroupData +from mlos_bench.storage.base_tunable_config_trial_group_data import ( + TunableConfigTrialGroupData, +) from mlos_bench.storage.sql import common from mlos_bench.storage.sql.schema import DbSchema from mlos_bench.storage.sql.tunable_config_data import TunableConfigSqlData diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index faa934b28f..220f3d812c 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -6,7 +6,7 @@ Factory method to create a new Storage instance from configs. """ -from typing import Any, Optional, List, Dict +from typing import Any, Dict, List, Optional from mlos_bench.config.schemas import ConfigSchema from mlos_bench.services.config_persistence import ConfigPersistenceService diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index d1c1781ada..26aa142441 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -6,22 +6,20 @@ Tests for mlos_bench. Used to make mypy happy about multiple conftest.py modules. """ +import filecmp +import os +import shutil +import socket from datetime import tzinfo from logging import debug, warning from subprocess import run from typing import List, Optional -import filecmp -import os -import socket -import shutil - -import pytz import pytest +import pytz from mlos_bench.util import get_class_from_name, nullable - ZONE_NAMES = [ # Explicit time zones. "UTC", diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index 75b2ae0cbe..4d728b4037 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -6,10 +6,9 @@ Helper functions for config example loading tests. """ -from typing import Callable, List, Optional - import os import sys +from typing import Callable, List, Optional from mlos_bench.util import path_join diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index 6fb341ff44..e1e26d7d8b 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -6,23 +6,21 @@ Tests for loading storage config examples. 
""" -from typing import List - import logging import sys +from typing import List import pytest -from mlos_bench.tests import check_class_name -from mlos_bench.tests.config import locate_config_examples, BUILTIN_TEST_CONFIG_PATH - from mlos_bench.config.schemas import ConfigSchema from mlos_bench.environments import Environment +from mlos_bench.launcher import Launcher from mlos_bench.optimizers import Optimizer -from mlos_bench.storage import Storage from mlos_bench.schedulers import Scheduler from mlos_bench.services.config_persistence import ConfigPersistenceService -from mlos_bench.launcher import Launcher +from mlos_bench.storage import Storage +from mlos_bench.tests import check_class_name +from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples from mlos_bench.util import path_join if sys.version_info < (3, 10): diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index d2f975fd85..42925a0a5d 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -10,15 +10,13 @@ import pytest -from mlos_bench.tests.config import locate_config_examples - from mlos_bench.config.schemas.config_schemas import ConfigSchema from mlos_bench.environments.base_environment import Environment from mlos_bench.environments.composite_env import CompositeEnv from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.tests.config import locate_config_examples from mlos_bench.tunables.tunable_groups import TunableGroups - _LOG = logging.getLogger(__name__) _LOG.setLevel(logging.DEBUG) diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index 4d3a2602b5..4d8c93fdff 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -10,11 +10,9 @@ import pytest -from mlos_bench.tests.config import locate_config_examples, BUILTIN_TEST_CONFIG_PATH - from mlos_bench.config.schemas import ConfigSchema from mlos_bench.services.config_persistence import ConfigPersistenceService - +from mlos_bench.tests.config import BUILTIN_TEST_CONFIG_PATH, locate_config_examples _LOG = logging.getLogger(__name__) _LOG.setLevel(logging.DEBUG) diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index bd9099b608..6cb6253dea 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -10,15 +10,13 @@ import pytest -from mlos_bench.tests.config import locate_config_examples - from mlos_bench.config.schemas import ConfigSchema -from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.optimizers.base_optimizer import Optimizer +from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.tests.config import locate_config_examples from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import get_class_from_name - _LOG = 
logging.getLogger(__name__) _LOG.setLevel(logging.DEBUG) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index 7e6edacbb3..e4264003e1 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -6,12 +6,11 @@ Common tests for config schemas and their validation and test cases. """ +import os from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, Set -import os - import json5 import jsonschema import pytest diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index ffc0add973..5dd1666008 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -11,11 +11,11 @@ import pytest from mlos_bench.config.schemas import ConfigSchema - -from mlos_bench.tests.config.schemas import (get_schema_test_cases, - check_test_case_against_schema, - check_test_case_config_with_extra_param) - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + check_test_case_config_with_extra_param, + get_schema_test_cases, +) # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index 8d1c5135d0..dc3cd40425 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -10,18 +10,17 @@ import pytest -from mlos_core.tests import get_all_concrete_subclasses - from mlos_bench.config.schemas import ConfigSchema from mlos_bench.environments.base_environment import Environment from mlos_bench.environments.composite_env import CompositeEnv from mlos_bench.environments.script_env import ScriptEnv - from mlos_bench.tests import try_resolve_class_name -from mlos_bench.tests.config.schemas import (get_schema_test_cases, - check_test_case_against_schema, - check_test_case_config_with_extra_param) - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + check_test_case_config_with_extra_param, + get_schema_test_cases, +) +from mlos_core.tests import get_all_concrete_subclasses # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 59d3ddd866..5045bf510b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -11,9 +11,10 @@ import pytest from mlos_bench.config.schemas import ConfigSchema - -from mlos_bench.tests.config.schemas import get_schema_test_cases, check_test_case_against_schema - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + get_schema_test_cases, +) # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py 
b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index e69c50c4bd..e9ee653644 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -11,18 +11,17 @@ import pytest -from mlos_core.optimizers import OptimizerType -from mlos_core.spaces.adapters import SpaceAdapterType -from mlos_core.tests import get_all_concrete_subclasses - from mlos_bench.config.schemas import ConfigSchema from mlos_bench.optimizers.base_optimizer import Optimizer - from mlos_bench.tests import try_resolve_class_name -from mlos_bench.tests.config.schemas import (get_schema_test_cases, - check_test_case_against_schema, - check_test_case_config_with_extra_param) - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + check_test_case_config_with_extra_param, + get_schema_test_cases, +) +from mlos_core.optimizers import OptimizerType +from mlos_core.spaces.adapters import SpaceAdapterType +from mlos_core.tests import get_all_concrete_subclasses # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 6e625b8ef2..8fccba8bc7 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -10,16 +10,15 @@ import pytest -from mlos_core.tests import get_all_concrete_subclasses - from mlos_bench.config.schemas import ConfigSchema from mlos_bench.schedulers.base_scheduler import Scheduler - from mlos_bench.tests import try_resolve_class_name -from mlos_bench.tests.config.schemas import (get_schema_test_cases, - check_test_case_against_schema, - check_test_case_config_with_extra_param) - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + check_test_case_config_with_extra_param, + get_schema_test_cases, +) +from mlos_core.tests import get_all_concrete_subclasses # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index c96346ad7b..64c6fccccd 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -11,20 +11,21 @@ import pytest -from mlos_core.tests import get_all_concrete_subclasses - from mlos_bench.config.schemas import ConfigSchema from mlos_bench.services.base_service import Service from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.services.local.temp_dir_context import TempDirContextService -from mlos_bench.services.remote.azure.azure_deployment_services import AzureDeploymentService +from mlos_bench.services.remote.azure.azure_deployment_services import ( + AzureDeploymentService, +) from mlos_bench.services.remote.ssh.ssh_service import SshService - from mlos_bench.tests import try_resolve_class_name -from mlos_bench.tests.config.schemas import (get_schema_test_cases, - check_test_case_against_schema, - check_test_case_config_with_extra_param) - +from mlos_bench.tests.config.schemas 
import ( + check_test_case_against_schema, + check_test_case_config_with_extra_param, + get_schema_test_cases, +) +from mlos_core.tests import get_all_concrete_subclasses # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index 7c42b85c4b..9b362b5e0d 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -10,16 +10,15 @@ import pytest -from mlos_core.tests import get_all_concrete_subclasses - from mlos_bench.config.schemas import ConfigSchema from mlos_bench.storage.base_storage import Storage - from mlos_bench.tests import try_resolve_class_name -from mlos_bench.tests.config.schemas import (get_schema_test_cases, - check_test_case_against_schema, - check_test_case_config_with_extra_param) - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + check_test_case_config_with_extra_param, + get_schema_test_cases, +) +from mlos_core.tests import get_all_concrete_subclasses # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py index fda78a19f9..a6d0de9313 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py @@ -11,9 +11,10 @@ import pytest from mlos_bench.config.schemas import ConfigSchema - -from mlos_bench.tests.config.schemas import get_schema_test_cases, check_test_case_against_schema - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + get_schema_test_cases, +) # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index 043bb725bc..d871eaa212 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -11,9 +11,10 @@ import pytest from mlos_bench.config.schemas import ConfigSchema - -from mlos_bench.tests.config.schemas import get_schema_test_cases, check_test_case_against_schema - +from mlos_bench.tests.config.schemas import ( + check_test_case_against_schema, + get_schema_test_cases, +) # General testing strategy: # - hand code a set of good/bad configs (useful to test editor schema checking) diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index f3da324dee..32034eb11c 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -10,12 +10,10 @@ import pytest -from mlos_bench.tests.config import locate_config_examples - from mlos_bench.config.schemas.config_schemas import ConfigSchema 
from mlos_bench.services.base_service import Service from mlos_bench.services.config_persistence import ConfigPersistenceService - +from mlos_bench.tests.config import locate_config_examples _LOG = logging.getLogger(__name__) _LOG.setLevel(logging.DEBUG) diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index 039d49948f..2f9773a9b0 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -10,14 +10,12 @@ import pytest -from mlos_bench.tests.config import locate_config_examples - from mlos_bench.config.schemas.config_schemas import ConfigSchema from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.storage.base_storage import Storage +from mlos_bench.tests.config import locate_config_examples from mlos_bench.util import get_class_from_name - _LOG = logging.getLogger(__name__) _LOG.setLevel(logging.DEBUG) diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index f30fe43585..58359eb983 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -6,19 +6,17 @@ Common fixtures for mock TunableGroups and Environment objects. """ -from typing import Any, Generator, List - import os - -from fasteners import InterProcessLock, InterProcessReaderWriterLock -from pytest_docker.plugin import get_docker_services, Services as DockerServices +from typing import Any, Generator, List import pytest +from fasteners import InterProcessLock, InterProcessReaderWriterLock +from pytest_docker.plugin import Services as DockerServices +from pytest_docker.plugin import get_docker_services from mlos_bench.environments.mock_env import MockEnv -from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.tests import SEED, tunable_groups_fixtures +from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name # -- Ignore pylint complaints about pytest references to diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py index e33188a9e3..ac0b942167 100644 --- a/mlos_bench/mlos_bench/tests/environments/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/__init__.py @@ -12,7 +12,6 @@ import pytest from mlos_bench.environments.base_environment import Environment - from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index 69253e31c1..8afb8e5cda 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -10,8 +10,8 @@ import pytest -from mlos_bench.tunables.tunable import TunableValue from mlos_bench.environments.base_environment import Environment +from mlos_bench.tunables.tunable import TunableValue _GROUPS = { "group": ["a", "b"], diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index e135a868ef..6497eb6985 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ 
b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -10,9 +10,9 @@ import pytest from mlos_bench.environments.composite_env import CompositeEnv -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.services.local.local_exec import LocalExecService +from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.util import path_join # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index fd7c022939..742eaf3c79 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -9,8 +9,8 @@ import pytest from mlos_bench.environments.composite_env import CompositeEnv -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index 4815e1c50d..9bcb7aa218 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -9,13 +9,13 @@ from datetime import datetime, timedelta, tzinfo from typing import Optional -from pytz import UTC import pytest +from pytz import UTC -from mlos_bench.tunables.tunable_groups import TunableGroups +from mlos_bench.tests import ZONE_INFO from mlos_bench.tests.environments import check_env_success from mlos_bench.tests.environments.local import create_composite_local_env -from mlos_bench.tests import ZONE_INFO +from mlos_bench.tunables.tunable_groups import TunableGroups def _format_str(zone_info: Optional[tzinfo]) -> str: diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py index 7b4de8c237..20854b9f9e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py @@ -8,9 +8,9 @@ import sys -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.tests.environments import check_env_success from mlos_bench.tests.environments.local import create_local_env +from mlos_bench.tunables.tunable_groups import TunableGroups def test_local_env_stdout(tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index ba104da542..35bdb39486 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -8,14 +8,13 @@ from datetime import datetime, timedelta, tzinfo from typing import Optional -from pytz import UTC - import pytest +from pytz import UTC -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.tests import ZONE_INFO -from mlos_bench.tests.environments import check_env_success, check_env_fail_telemetry +from mlos_bench.tests.environments import check_env_fail_telemetry, check_env_success 
from mlos_bench.tests.environments.local import create_local_env +from mlos_bench.tunables.tunable_groups import TunableGroups def _format_str(zone_info: Optional[tzinfo]) -> str: diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 9fcd26ead2..6cb4fd4f7e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -7,9 +7,9 @@ """ import pytest -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.tests.environments import check_env_success from mlos_bench.tests.environments.local import create_local_env +from mlos_bench.tunables.tunable_groups import TunableGroups def test_local_env(tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py index ac7ff257e1..c16eac4459 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py @@ -9,9 +9,9 @@ import pytest -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.tests.environments import check_env_success from mlos_bench.tests.environments.local import create_local_env +from mlos_bench.tunables.tunable_groups import TunableGroups def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: dict) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index bb455b8b76..8bce053f7b 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -7,12 +7,13 @@ """ import pytest -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.services.local.local_exec import LocalExecService - -from mlos_bench.tests.services.remote.mock.mock_fileshare_service import MockFileShareService +from mlos_bench.tests.services.remote.mock.mock_fileshare_service import ( + MockFileShareService, +) +from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index 36ea7c324b..878531d799 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -6,22 +6,19 @@ Unit tests for RemoveEnv benchmark environment via local SSH test services. 
""" -from typing import Dict - import os import sys +from typing import Dict import numpy as np - import pytest from mlos_bench.services.config_persistence import ConfigPersistenceService -from mlos_bench.tunables.tunable import TunableValue -from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.tests import requires_docker from mlos_bench.tests.environments import check_env_success from mlos_bench.tests.services.remote.ssh import SshTestServerInfo +from mlos_bench.tunables.tunable import TunableValue +from mlos_bench.tunables.tunable_groups import TunableGroups if sys.version_info < (3, 10): from importlib_resources import files diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index 80b252f255..377bc940a0 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -9,14 +9,13 @@ import asyncio import sys import time - from asyncio import AbstractEventLoop from threading import Thread from types import TracebackType from typing import Optional, Type -from typing_extensions import Literal import pytest +from typing_extensions import Literal from mlos_bench.event_loop_context import EventLoopContext diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index 90e52bb880..634050d099 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -14,11 +14,10 @@ import pytest +from mlos_bench.config.schemas import ConfigSchema from mlos_bench.launcher import Launcher -from mlos_bench.optimizers import OneShotOptimizer, MlosCoreOptimizer +from mlos_bench.optimizers import MlosCoreOptimizer, OneShotOptimizer from mlos_bench.os_environ import environ -from mlos_bench.config.schemas import ConfigSchema -from mlos_bench.util import path_join from mlos_bench.schedulers import SyncScheduler from mlos_bench.services.types import ( SupportsAuth, @@ -28,6 +27,7 @@ SupportsRemoteExec, ) from mlos_bench.tests import check_class_name +from mlos_bench.util import path_join if sys.version_info < (3, 10): from importlib_resources import files diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index d8caf7537e..591501d275 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -11,8 +11,8 @@ import pytest -from mlos_bench.services.local.local_exec import LocalExecService from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.services.local.local_exec import LocalExecService from mlos_bench.util import path_join # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index 7149a79c93..59a0fac13b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -10,11 +10,10 @@ import pytest -from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer - +from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.tests import SEED +from mlos_bench.tunables.tunable_groups import TunableGroups @pytest.fixture diff --git 
a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index 9e43b3731e..9e9ce25d6f 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -6,11 +6,10 @@ Unit tests for grid search mlos_bench optimizer. """ -from typing import Dict, List - import itertools import math import random +from typing import Dict, List import pytest @@ -19,7 +18,6 @@ from mlos_bench.tunables.tunable import TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups - # pylint: disable=redefined-outer-name @pytest.fixture diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py index d356466e58..6549a8795c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py @@ -9,10 +9,9 @@ import pytest from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer - from mlos_bench.tests import SEED +from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py index f36e3c149c..7ebba0e664 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py @@ -12,9 +12,8 @@ import pytest from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer -from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.tests import SEED +from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name, protected-access diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index b10571095b..fc62b4ff1b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -6,17 +6,15 @@ Unit tests for mock mlos_bench optimizer. 
""" import os -import sys import shutil +import sys import pytest -from mlos_bench.util import path_join from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer -from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.tests import SEED - +from mlos_bench.tunables.tunable_groups import TunableGroups +from mlos_bench.util import path_join from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer _OUTPUT_DIR_PATH_BASE = r'c:/temp' if sys.platform == 'win32' else '/tmp/' diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py index f2805e9322..bf37040f13 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py @@ -12,8 +12,8 @@ from mlos_bench.environments.status import Status from mlos_bench.optimizers.base_optimizer import Optimizer -from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer +from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.tunables.tunable import TunableValue # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index 183db1dc62..2a50f95e8c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -6,23 +6,20 @@ Toy optimization loop to test the optimizers on mock benchmark environment. """ -from typing import Tuple - import logging +from typing import Tuple import pytest -from mlos_core.util import config_to_dataframe -from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer -from mlos_bench.optimizers.convert_configspace import tunable_values_to_configuration - from mlos_bench.environments.base_environment import Environment from mlos_bench.environments.mock_env import MockEnv -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.optimizers.base_optimizer import Optimizer -from mlos_bench.optimizers.mock_optimizer import MockOptimizer +from mlos_bench.optimizers.convert_configspace import tunable_values_to_configuration from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer - +from mlos_bench.optimizers.mock_optimizer import MockOptimizer +from mlos_bench.tunables.tunable_groups import TunableGroups +from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer +from mlos_core.util import config_to_dataframe # For debugging purposes output some warnings which are captured with failed tests. 
DEBUG = True diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index 55dc15a8d4..d6cb869f09 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -8,13 +8,13 @@ import os import sys + import pytest from mlos_bench.config.schemas import ConfigSchema from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.util import path_join - if sys.version_info < (3, 9): from importlib_resources import files else: diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index 6f8549aee7..572195dcc5 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -6,15 +6,14 @@ Unit tests for LocalExecService to run Python scripts locally. """ -from typing import Any, Dict - import json +from typing import Any, Dict import pytest -from mlos_bench.tunables.tunable import TunableValue -from mlos_bench.services.local.local_exec import LocalExecService from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.services.local.local_exec import LocalExecService +from mlos_bench.tunables.tunable import TunableValue from mlos_bench.util import path_join # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index 6e56b3bbe2..bd5b3b7d7f 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -8,11 +8,11 @@ import sys import tempfile -import pytest import pandas +import pytest -from mlos_bench.services.local.local_exec import LocalExecService, split_cmdline from mlos_bench.services.config_persistence import ConfigPersistenceService +from mlos_bench.services.local.local_exec import LocalExecService, split_cmdline from mlos_bench.util import path_join # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py index 588b94d8ea..db8f0134c4 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py @@ -8,7 +8,16 @@ import logging from typing import ( - Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, Union + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Union, ) from mlos_bench.services.base_service import Service diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py index dc3b5469be..9bf6e49541 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py @@ -5,9 +5,9 @@ """ Tests helpers for mlos_bench.services.remote.azure. 
""" +import json from io import BytesIO -import json import urllib3 diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index 199c42f1fb..949b712c79 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -7,7 +7,7 @@ """ import os -from unittest.mock import MagicMock, Mock, patch, call +from unittest.mock import MagicMock, Mock, call, patch from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareService diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py index 22fec74c74..d6d55d3975 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py @@ -12,10 +12,8 @@ import requests.exceptions as requests_ex from mlos_bench.environments.status import Status - from mlos_bench.services.remote.azure.azure_auth import AzureAuthService from mlos_bench.services.remote.azure.azure_network_services import AzureNetworkService - from mlos_bench.tests.services.remote.azure import make_httplib_json_response diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 97bf904a56..1d84d73cab 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -13,10 +13,8 @@ import requests.exceptions as requests_ex from mlos_bench.environments.status import Status - from mlos_bench.services.remote.azure.azure_auth import AzureAuthService from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService - from mlos_bench.tests.services.remote.azure import make_httplib_json_response diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 45feb1aa50..2794bb01cf 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -13,9 +13,9 @@ from mlos_bench.services.config_persistence import ConfigPersistenceService from mlos_bench.services.remote.azure import ( AzureAuthService, + AzureFileShareService, AzureNetworkService, AzureVMService, - AzureFileShareService, ) # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index c09b31b299..1a026966a8 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -9,8 +9,8 @@ import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union -from mlos_bench.services.base_service import Service from mlos_bench.services.base_fileshare import FileShareService +from mlos_bench.services.base_service import Service from mlos_bench.services.types.fileshare_type import SupportsFileShareOps _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py 
b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py index 6d2bd058b9..e6169d9f93 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py @@ -9,7 +9,9 @@ from typing import Any, Callable, Dict, List, Optional, Union from mlos_bench.services.base_service import Service -from mlos_bench.services.types.network_provisioner_type import SupportsNetworkProvisioning +from mlos_bench.services.types.network_provisioner_type import ( + SupportsNetworkProvisioning, +) from mlos_bench.tests.services.remote.mock import mock_operation diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index 88896c3f16..a44edaf080 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -9,8 +9,8 @@ from typing import Any, Callable, Dict, List, Optional, Union from mlos_bench.services.base_service import Service -from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning from mlos_bench.services.types.host_ops_type import SupportsHostOps +from mlos_bench.services.types.host_provisioner_type import SupportsHostProvisioning from mlos_bench.services.types.os_ops_type import SupportsOSOps from mlos_bench.tests.services.remote.mock import mock_operation diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index eb285ffc7d..e0060d8047 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -14,7 +14,6 @@ from mlos_bench.tests import check_socket - # The SSH test server port and name. 
# See Also: docker-compose.yml SSH_TEST_SERVER_PORT = 2254 diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 1706a42969..6f05fe953b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -8,25 +8,25 @@ Note: these are not in the conftest.py file because they are also used by remote_ssh_env_test.py """ -from typing import Generator -from subprocess import run - import os import sys import tempfile +from subprocess import run +from typing import Generator import pytest from pytest_docker.plugin import Services as DockerServices -from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService - +from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.tests import resolve_host_name -from mlos_bench.tests.services.remote.ssh import (SshTestServerInfo, - ALT_TEST_SERVER_NAME, - REBOOT_TEST_SERVER_NAME, - SSH_TEST_SERVER_NAME, - wait_docker_service_socket) +from mlos_bench.tests.services.remote.ssh import ( + ALT_TEST_SERVER_NAME, + REBOOT_TEST_SERVER_NAME, + SSH_TEST_SERVER_NAME, + SshTestServerInfo, + wait_docker_service_socket, +) # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index ab25093fd0..f2bbbe4b8a 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -6,23 +6,21 @@ Tests for mlos_bench.services.remote.ssh.ssh_services """ +import os +import tempfile from contextlib import contextmanager from os.path import basename from pathlib import Path from tempfile import _TemporaryFileWrapper # pylint: disable=import-private-name from typing import Any, Dict, Generator, List -import os -import tempfile - import pytest -from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService -from mlos_bench.util import path_join - +from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.tests import are_dir_trees_equal, requires_docker from mlos_bench.tests.services.remote.ssh import SshTestServerInfo +from mlos_bench.util import path_join @contextmanager diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 03b7eb56a8..4c8e5e0c66 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -6,22 +6,22 @@ Tests for mlos_bench.services.remote.ssh.ssh_host_service """ -from subprocess import CalledProcessError, run - import logging import time +from subprocess import CalledProcessError, run from pytest_docker.plugin import Services as DockerServices from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.services.remote.ssh.ssh_service import SshClient - from mlos_bench.tests import requires_docker -from mlos_bench.tests.services.remote.ssh import (SshTestServerInfo, - ALT_TEST_SERVER_NAME, - REBOOT_TEST_SERVER_NAME, - SSH_TEST_SERVER_NAME, - wait_docker_service_socket) +from 
mlos_bench.tests.services.remote.ssh import ( + ALT_TEST_SERVER_NAME, + REBOOT_TEST_SERVER_NAME, + SSH_TEST_SERVER_NAME, + SshTestServerInfo, + wait_docker_service_socket, +) _LOG = logging.getLogger(__name__) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index fd0804ba15..7bee929fea 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -7,22 +7,28 @@ """ import asyncio -from importlib.metadata import version, PackageNotFoundError import time - +from importlib.metadata import PackageNotFoundError, version from subprocess import run from threading import Thread import pytest from pytest_lazy_fixtures.lazy_fixture import lf as lazy_fixture -from mlos_bench.services.remote.ssh.ssh_service import SshService -from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService from mlos_bench.services.remote.ssh.ssh_fileshare import SshFileShareService - -from mlos_bench.tests import requires_docker, requires_ssh, check_socket, resolve_host_name -from mlos_bench.tests.services.remote.ssh import SshTestServerInfo, ALT_TEST_SERVER_NAME, SSH_TEST_SERVER_NAME - +from mlos_bench.services.remote.ssh.ssh_host_service import SshHostService +from mlos_bench.services.remote.ssh.ssh_service import SshService +from mlos_bench.tests import ( + check_socket, + requires_docker, + requires_ssh, + resolve_host_name, +) +from mlos_bench.tests.services.remote.ssh import ( + ALT_TEST_SERVER_NAME, + SSH_TEST_SERVER_NAME, + SshTestServerInfo, +) if version("pytest") >= "8.0.0": try: diff --git a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py index 736a3d5ef2..463879634f 100644 --- a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py +++ b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py @@ -9,8 +9,11 @@ import pytest from mlos_bench.services.base_service import Service - -from mlos_bench.tests.services.mock_service import SupportsSomeMethod, MockServiceBase, MockServiceChild +from mlos_bench.tests.services.mock_service import ( + MockServiceBase, + MockServiceChild, + SupportsSomeMethod, +) def test_service_method_register_without_constructor() -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index c37dc433b0..8159043be1 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -6,11 +6,10 @@ Unit tests for loading the experiment metadata. 
""" -from mlos_bench.storage.base_storage import Storage from mlos_bench.storage.base_experiment_data import ExperimentData -from mlos_bench.tunables.tunable_groups import TunableGroups - +from mlos_bench.storage.base_storage import Storage from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT +from mlos_bench.tunables.tunable_groups import TunableGroups def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index cd6b17be74..d0a5edc694 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -8,14 +8,13 @@ from datetime import datetime, tzinfo from typing import Optional -from pytz import UTC - import pytest +from pytz import UTC from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.storage.base_storage import Storage from mlos_bench.tests import ZONE_INFO +from mlos_bench.tunables.tunable_groups import TunableGroups def test_exp_load_empty(exp_storage: Storage.Experiment) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index b86ebe6c18..7e346a5ccc 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -7,21 +7,20 @@ """ from datetime import datetime -from random import random, seed as rand_seed +from random import random +from random import seed as rand_seed from typing import Generator, Optional -from pytz import UTC - import pytest +from pytz import UTC from mlos_bench.environments.status import Status +from mlos_bench.optimizers.mock_optimizer import MockOptimizer from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.sql.storage import SqlStorage -from mlos_bench.optimizers.mock_optimizer import MockOptimizer -from mlos_bench.tunables.tunable_groups import TunableGroups - from mlos_bench.tests import SEED from mlos_bench.tests.storage import CONFIG_COUNT, CONFIG_TRIAL_REPEAT_COUNT +from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index 21f857ae45..04f4f18ae3 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -6,7 +6,6 @@ Unit tests for scheduling trials for some future time. 
""" from datetime import datetime, timedelta - from typing import Iterator, Set from pytz import UTC diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index deea02128f..855c6cd861 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -8,16 +8,14 @@ from datetime import datetime, timedelta, tzinfo from typing import Any, List, Optional, Tuple -from pytz import UTC - import pytest +from pytz import UTC from mlos_bench.environments.status import Status -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.storage.base_storage import Storage -from mlos_bench.util import nullable - from mlos_bench.tests import ZONE_INFO +from mlos_bench.tunables.tunable_groups import TunableGroups +from mlos_bench.util import nullable # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index 736621a3fd..d08b26e92d 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -6,10 +6,9 @@ Unit tests for loading the TunableConfigTrialGroupData. """ -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.storage.base_experiment_data import ExperimentData - from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT +from mlos_bench.tunables.tunable_groups import TunableGroups def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None: diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py index cd279f4e5d..fa947610da 100644 --- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py +++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py @@ -6,16 +6,15 @@ Tests various other test scenarios with alternative default (un-named) TZ info. """ -from subprocess import run import os import sys +from subprocess import run from typing import Optional import pytest from mlos_bench.tests import ZONE_NAMES - DIRNAME = os.path.dirname(__file__) TZ_TEST_FILES = [ DIRNAME + "/environments/local/composite_local_env_test.py", diff --git a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py index 0350fff3bb..822547b1da 100644 --- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py +++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py @@ -8,9 +8,8 @@ from typing import Any, Dict -import pytest - import json5 as json +import pytest from mlos_bench.config.schemas import ConfigSchema from mlos_bench.tunables.covariant_group import CovariantTunableGroup diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index de966536c4..58bb0368b1 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -11,7 +11,6 @@ from mlos_bench.tunables.tunable import Tunable - # Note: these test do *not* check the ConfigSpace conversions for those same Tunables. 
# That is checked indirectly via grid_search_optimizer_test.py diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index 1024ba992b..73e3a12caa 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -8,24 +8,22 @@ """ import pytest - from ConfigSpace import ( - CategoricalHyperparameter, BetaFloatHyperparameter, BetaIntegerHyperparameter, + CategoricalHyperparameter, NormalFloatHyperparameter, NormalIntegerHyperparameter, UniformFloatHyperparameter, UniformIntegerHyperparameter, ) -from mlos_bench.tunables.tunable import DistributionName -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.optimizers.convert_configspace import ( special_param_names, tunable_groups_to_configspace, ) - +from mlos_bench.tunables.tunable import DistributionName +from mlos_bench.tunables.tunable_groups import TunableGroups _CS_HYPERPARAMETER = { ("float", "beta"): BetaFloatHyperparameter, diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 42b24dd51e..78e91fd25e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -7,7 +7,6 @@ """ import pytest - from ConfigSpace import ( CategoricalHyperparameter, ConfigurationSpace, @@ -16,14 +15,14 @@ UniformIntegerHyperparameter, ) -from mlos_bench.tunables.tunable import Tunable -from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.optimizers.convert_configspace import ( TunableValueKind, _tunable_to_configspace, special_param_names, tunable_groups_to_configspace, ) +from mlos_bench.tunables.tunable import Tunable +from mlos_bench.tunables.tunable_groups import TunableGroups # pylint: disable=redefined-outer-name diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index beb8db15ec..fee4fd5841 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -6,7 +6,6 @@ Tunable parameter definition. """ import copy - from typing import Dict, Iterable, Union from mlos_bench.tunables.tunable import Tunable, TunableValue diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 26eb719866..1ebd70dfa4 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -5,11 +5,22 @@ """ Tunable parameter definition. """ -import copy import collections +import copy import logging - -from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Type, TypedDict, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + TypedDict, + Union, +) import numpy as np diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index f97bf9de7d..0bd58c8269 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -6,12 +6,11 @@ TunableGroups definition. 
""" import copy - from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union from mlos_bench.config.schemas import ConfigSchema -from mlos_bench.tunables.tunable import Tunable, TunableValue from mlos_bench.tunables.covariant_group import CovariantTunableGroup +from mlos_bench.tunables.tunable import Tunable, TunableValue class TunableGroups: diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index eb7dd3990d..531988be97 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -8,22 +8,30 @@ # NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports. -from datetime import datetime -import os +import importlib import json import logging -import importlib +import os import subprocess - +from datetime import datetime from typing import ( - Any, Callable, Dict, Iterable, Literal, Mapping, Optional, - Tuple, Type, TypeVar, TYPE_CHECKING, Union, + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Literal, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, ) import pandas import pytz - _LOG = logging.getLogger(__name__) if TYPE_CHECKING: diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index be59c22477..27d844c35b 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -8,12 +8,11 @@ # pylint: disable=duplicate-code -from logging import warning -from itertools import chain -from typing import Dict, List - import os import re +from itertools import chain +from logging import warning +from typing import Dict, List from setuptools import setup diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index b00a9e8eb1..086002af62 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -11,11 +11,11 @@ import ConfigSpace -from mlos_core.optimizers.optimizer import BaseOptimizer -from mlos_core.optimizers.random_optimizer import RandomOptimizer from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer from mlos_core.optimizers.flaml_optimizer import FlamlOptimizer -from mlos_core.spaces.adapters import SpaceAdapterType, SpaceAdapterFactory +from mlos_core.optimizers.optimizer import BaseOptimizer +from mlos_core.optimizers.random_optimizer import RandomOptimizer +from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType __all__ = [ 'SpaceAdapterType', diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py index 55f0aa09eb..5f32219988 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py @@ -6,10 +6,11 @@ Basic initializer module for the mlos_core Bayesian optimizers. 
""" -from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import BaseBayesianOptimizer +from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import ( + BaseBayesianOptimizer, +) from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer - __all__ = [ 'BaseBayesianOptimizer', 'SmacOptimizer', diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 2de01637f8..76ff0d9b3a 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -7,11 +7,10 @@ """ from abc import ABCMeta, abstractmethod - from typing import Optional -import pandas as pd import numpy.typing as npt +import pandas as pd from mlos_core.optimizers.optimizer import BaseOptimizer diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index aa948b8125..9d8d2a0347 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -9,15 +9,17 @@ from logging import warning from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union, TYPE_CHECKING from tempfile import TemporaryDirectory +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from warnings import warn import ConfigSpace import numpy.typing as npt import pandas as pd -from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import BaseBayesianOptimizer +from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import ( + BaseBayesianOptimizer, +) from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter from mlos_core.spaces.adapters.identity_adapter import IdentityAdapter @@ -259,7 +261,11 @@ def _register(self, *, configs: pd.DataFrame, metadata: pd.DataFrame Not Yet Implemented. """ - from smac.runhistory import StatusType, TrialInfo, TrialValue # pylint: disable=import-outside-toplevel + from smac.runhistory import ( # pylint: disable=import-outside-toplevel + StatusType, + TrialInfo, + TrialValue, + ) if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) @@ -292,7 +298,9 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr Not yet implemented. 
""" if TYPE_CHECKING: - from smac.runhistory import TrialInfo # pylint: disable=import-outside-toplevel,unused-import + from smac.runhistory import ( + TrialInfo, # pylint: disable=import-outside-toplevel,unused-import + ) if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) @@ -311,7 +319,9 @@ def register_pending(self, *, configs: pd.DataFrame, raise NotImplementedError() def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: - from smac.utils.configspace import convert_configurations_to_array # pylint: disable=import-outside-toplevel + from smac.utils.configspace import ( + convert_configurations_to_array, # pylint: disable=import-outside-toplevel + ) if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 4f478db2bf..273c89eecc 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -13,9 +13,9 @@ import numpy as np import pandas as pd -from mlos_core.util import normalize_config from mlos_core.optimizers.optimizer import BaseOptimizer from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter +from mlos_core.util import normalize_config class EvaluatedSample(NamedTuple): @@ -77,7 +77,10 @@ def __init__(self, *, # pylint: disable=too-many-arguments np.random.seed(seed) # pylint: disable=import-outside-toplevel - from mlos_core.spaces.converters.flaml import configspace_to_flaml_space, FlamlDomain + from mlos_core.spaces.converters.flaml import ( + FlamlDomain, + configspace_to_flaml_space, + ) self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(self.optimizer_parameter_space) self.low_cost_partial_config = low_cost_partial_config diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 8fcf592a6c..4ab9db5a2f 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -15,8 +15,8 @@ import numpy.typing as npt import pandas as pd -from mlos_core.util import config_to_dataframe from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter +from mlos_core.util import config_to_dataframe class BaseOptimizer(metaclass=ABCMeta): diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 554b1169f5..4d3a925cbc 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -9,14 +9,14 @@ from warnings import warn import ConfigSpace -from ConfigSpace.hyperparameters import NumericalHyperparameter import numpy as np import numpy.typing as npt import pandas as pd +from ConfigSpace.hyperparameters import NumericalHyperparameter from sklearn.preprocessing import MinMaxScaler -from mlos_core.util import normalize_config from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter +from mlos_core.util import normalize_config class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes @@ -350,7 +350,10 @@ def _try_generate_approx_inverse_mapping(self) -> None: ------ RuntimeError: if reverse mapping computation fails. 
""" - from scipy.linalg import pinv, LinAlgError # pylint: disable=import-outside-toplevel + from scipy.linalg import ( # pylint: disable=import-outside-toplevel + LinAlgError, + pinv, + ) warn("Trying to register a configuration that was not previously suggested by the optimizer. " + "This inverse configuration transformation is typically not supported. " + diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index 3935dbef6c..d6918f9891 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -6,15 +6,13 @@ Contains space converters for FLAML. """ -from typing import Dict, TYPE_CHECKING - import sys +from typing import TYPE_CHECKING, Dict import ConfigSpace -import numpy as np - import flaml.tune import flaml.tune.sample +import numpy as np if TYPE_CHECKING: from ConfigSpace.hyperparameters import Hyperparameter diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index 6f74147ae9..a8ad146205 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -7,7 +7,6 @@ """ import sys - from importlib import import_module from pkgutil import walk_packages from typing import List, Optional, Set, Type, TypeVar diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index c1aaa710ac..c7a94dfcc4 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -8,10 +8,9 @@ from typing import Optional, Type -import pytest - -import pandas as pd import ConfigSpace as CS +import pandas as pd +import pytest from mlos_core.optimizers import BaseOptimizer, OptimizerType from mlos_core.optimizers.bayesian_optimizers import BaseBayesianOptimizer diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py index be1b658387..39231bec5c 100644 --- a/mlos_core/mlos_core/tests/optimizers/conftest.py +++ b/mlos_core/mlos_core/tests/optimizers/conftest.py @@ -6,9 +6,8 @@ Test fixtures for mlos_bench optimizers. """ -import pytest - import ConfigSpace as CS +import pytest @pytest.fixture diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index 8e10afa302..725d92fbe9 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -6,12 +6,11 @@ Tests for one-hot encoding for certain optimizers. 
""" -import pytest - -import pandas as pd +import ConfigSpace as CS import numpy as np import numpy.typing as npt -import ConfigSpace as CS +import pandas as pd +import pytest from mlos_core.optimizers import BaseOptimizer, SmacOptimizer diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index 22263b4c1d..0b9d624a7a 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -9,14 +9,12 @@ import logging from typing import List, Optional, Type -import pytest - -import pandas as pd -import numpy as np import ConfigSpace as CS +import numpy as np +import pandas as pd +import pytest -from mlos_core.optimizers import OptimizerType, BaseOptimizer - +from mlos_core.optimizers import BaseOptimizer, OptimizerType from mlos_core.tests import SEED _LOG = logging.getLogger(__name__) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index 8231e59feb..5fd28ca1ed 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -6,24 +6,27 @@ Tests for Bayesian Optimizers. """ +import logging from copy import deepcopy from typing import List, Optional, Type -import logging -import pytest - -import pandas as pd -import numpy as np import ConfigSpace as CS +import numpy as np +import pandas as pd +import pytest from mlos_core.optimizers import ( - OptimizerType, ConcreteOptimizer, OptimizerFactory, BaseOptimizer) - -from mlos_core.optimizers.bayesian_optimizers import BaseBayesianOptimizer, SmacOptimizer + BaseOptimizer, + ConcreteOptimizer, + OptimizerFactory, + OptimizerType, +) +from mlos_core.optimizers.bayesian_optimizers import ( + BaseBayesianOptimizer, + SmacOptimizer, +) from mlos_core.spaces.adapters import SpaceAdapterType - -from mlos_core.tests import get_all_concrete_subclasses, SEED - +from mlos_core.tests import SEED, get_all_concrete_subclasses _LOG = logging.getLogger(__name__) _LOG.setLevel(logging.DEBUG) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index 661decc288..84dcd4e5c0 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -10,10 +10,9 @@ from typing import Any, Dict, Iterator, List, Set -import pytest - import ConfigSpace as CS import pandas as pd +import pytest from mlos_core.spaces.adapters import LlamaTuneAdapter diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index 4f0c31538f..5390f97c5f 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -10,14 +10,16 @@ from typing import List, Optional, Type -import pytest - import ConfigSpace as CS +import pytest -from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType, ConcreteSpaceAdapter +from mlos_core.spaces.adapters import ( + ConcreteSpaceAdapter, + SpaceAdapterFactory, + SpaceAdapterType, +) from mlos_core.spaces.adapters.adapter import BaseSpaceAdapter from mlos_core.spaces.adapters.identity_adapter import IdentityAdapter - from mlos_core.tests import get_all_concrete_subclasses diff --git 
a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index f77852594e..dee9251652 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -11,19 +11,19 @@ from abc import ABCMeta, abstractmethod from typing import Any, Callable, List, NoReturn, Union +import ConfigSpace as CS +import flaml.tune.sample import numpy as np import numpy.typing as npt import pytest - import scipy - -import ConfigSpace as CS from ConfigSpace.hyperparameters import Hyperparameter, NormalIntegerHyperparameter -import flaml.tune.sample - -from mlos_core.spaces.converters.flaml import configspace_to_flaml_space, FlamlDomain, FlamlSpace - +from mlos_core.spaces.converters.flaml import ( + FlamlDomain, + FlamlSpace, + configspace_to_flaml_space, +) OptimizerSpace = Union[FlamlSpace, CS.ConfigurationSpace] OptimizerParam = Union[FlamlDomain, Hyperparameter] diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index 8acb654adf..df0e144535 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -8,8 +8,8 @@ from typing import Union -from ConfigSpace import Configuration, ConfigurationSpace import pandas as pd +from ConfigSpace import Configuration, ConfigurationSpace def config_to_dataframe(config: Configuration) -> pd.DataFrame: diff --git a/mlos_core/setup.py b/mlos_core/setup.py index 519b213d90..fed376d1af 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -8,13 +8,12 @@ # pylint: disable=duplicate-code +import os +import re from itertools import chain from logging import warning from typing import Dict, List -import os -import re - from setuptools import setup PKG_NAME = "mlos_core" diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py index a7ba74b1d7..2390554e1e 100644 --- a/mlos_viz/mlos_viz/__init__.py +++ b/mlos_viz/mlos_viz/__init__.py @@ -38,7 +38,7 @@ def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) """ base.ignore_plotter_warnings() if plotter_method == MlosVizMethod.DABL: - import mlos_viz.dabl # pylint: disable=import-outside-toplevel + import mlos_viz.dabl # pylint: disable=import-outside-toplevel mlos_viz.dabl.ignore_plotter_warnings() else: raise NotImplementedError(f"Unhandled method: {plotter_method}") @@ -80,7 +80,7 @@ def plot(exp_data: Optional[ExperimentData] = None, *, base.plot_top_n_configs(exp_data, results_df=results_df, objectives=objectives, **kwargs) if MlosVizMethod.DABL: - import mlos_viz.dabl # pylint: disable=import-outside-toplevel + import mlos_viz.dabl # pylint: disable=import-outside-toplevel mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives) else: raise NotImplementedError(f"Unhandled method: {plotter_method}") diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 787315313a..15358b0862 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -6,23 +6,20 @@ Base functions for visualizing, explain, and gain insights from results. 
""" -from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union - import re import warnings - from importlib.metadata import version +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union -from matplotlib import pyplot as plt import pandas +import seaborn as sns +from matplotlib import pyplot as plt from pandas.api.types import is_numeric_dtype from pandas.core.groupby.generic import SeriesGroupBy -import seaborn as sns from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_viz.util import expand_results_data_args - _SEABORN_VERS = version('seaborn') diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 112bf70470..504486a58c 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -5,15 +5,13 @@ """ Small wrapper functions for dabl plotting functions via mlos_bench data. """ -from typing import Dict, Optional, Literal - import warnings +from typing import Dict, Literal, Optional import dabl import pandas from mlos_bench.storage.base_experiment_data import ExperimentData - from mlos_viz.util import expand_results_data_args diff --git a/mlos_viz/mlos_viz/tests/__init__.py b/mlos_viz/mlos_viz/tests/__init__.py index d496cbe2b3..2aa5f430cf 100644 --- a/mlos_viz/mlos_viz/tests/__init__.py +++ b/mlos_viz/mlos_viz/tests/__init__.py @@ -10,7 +10,6 @@ import seaborn # pylint: disable=unused-import # (used by patch) # noqa: unused - BASE_MATPLOTLIB_SHOW_PATCH = "mlos_viz.base.plt.show" if sys.version_info >= (3, 11): diff --git a/mlos_viz/mlos_viz/tests/test_base_plot.py b/mlos_viz/mlos_viz/tests/test_base_plot.py index 9fb33471e6..52d571e742 100644 --- a/mlos_viz/mlos_viz/tests/test_base_plot.py +++ b/mlos_viz/mlos_viz/tests/test_base_plot.py @@ -7,13 +7,14 @@ """ import warnings - -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch from mlos_bench.storage.base_experiment_data import ExperimentData - -from mlos_viz.base import ignore_plotter_warnings, plot_optimizer_trends, plot_top_n_configs - +from mlos_viz.base import ( + ignore_plotter_warnings, + plot_optimizer_trends, + plot_top_n_configs, +) from mlos_viz.tests import BASE_MATPLOTLIB_SHOW_PATCH diff --git a/mlos_viz/mlos_viz/tests/test_dabl_plot.py b/mlos_viz/mlos_viz/tests/test_dabl_plot.py index 36c83b12f2..fc4dd3667a 100644 --- a/mlos_viz/mlos_viz/tests/test_dabl_plot.py +++ b/mlos_viz/mlos_viz/tests/test_dabl_plot.py @@ -7,13 +7,10 @@ """ import warnings - -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch from mlos_bench.storage.base_experiment_data import ExperimentData - from mlos_viz import dabl - from mlos_viz.tests import SEABORN_BOXPLOT_PATCH diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py index 0be7220f47..06ac4a7664 100644 --- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py +++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py @@ -8,13 +8,10 @@ import random import warnings - -from unittest.mock import patch, Mock +from unittest.mock import Mock, patch from mlos_bench.storage.base_experiment_data import ExperimentData - from mlos_viz import MlosVizMethod, plot - from mlos_viz.tests import BASE_MATPLOTLIB_SHOW_PATCH, SEABORN_BOXPLOT_PATCH diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index 7c8c4deb07..98d12598e1 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -8,16 +8,14 @@ # pylint: disable=duplicate-code -from logging import warning -from itertools import chain -from typing import Dict, List - import os 
import re +from itertools import chain +from logging import warning +from typing import Dict, List from setuptools import setup - PKG_NAME = "mlos_viz" try: From 1b9843a587b77daab06df362a0b109b2712b91a2 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 3 Jul 2024 21:22:23 +0000 Subject: [PATCH 06/54] black formatting --- .../fio/scripts/local/process_fio_results.py | 26 +- .../scripts/local/generate_redis_config.py | 12 +- .../scripts/local/process_redis_results.py | 30 +- .../boot/scripts/local/create_new_grub_cfg.py | 11 +- .../scripts/local/generate_grub_config.py | 10 +- .../local/generate_kernel_config_script.py | 5 +- .../mlos_bench/config/schemas/__init__.py | 4 +- .../config/schemas/config_schemas.py | 21 +- mlos_bench/mlos_bench/dict_templater.py | 21 +- .../mlos_bench/environments/__init__.py | 15 +- .../environments/base_environment.py | 118 ++-- .../mlos_bench/environments/composite_env.py | 73 ++- .../mlos_bench/environments/local/__init__.py | 4 +- .../environments/local/local_env.py | 139 +++-- .../environments/local/local_fileshare_env.py | 68 ++- .../mlos_bench/environments/mock_env.py | 52 +- .../environments/remote/__init__.py | 12 +- .../environments/remote/host_env.py | 37 +- .../environments/remote/network_env.py | 45 +- .../mlos_bench/environments/remote/os_env.py | 40 +- .../environments/remote/remote_env.py | 46 +- .../environments/remote/saas_env.py | 46 +- .../mlos_bench/environments/script_env.py | 45 +- mlos_bench/mlos_bench/event_loop_context.py | 14 +- mlos_bench/mlos_bench/launcher.py | 288 ++++++---- mlos_bench/mlos_bench/optimizers/__init__.py | 8 +- .../mlos_bench/optimizers/base_optimizer.py | 97 ++-- .../optimizers/convert_configspace.py | 122 +++-- .../optimizers/grid_search_optimizer.py | 91 +++- .../optimizers/mlos_core_optimizer.py | 100 ++-- .../mlos_bench/optimizers/mock_optimizer.py | 26 +- .../optimizers/one_shot_optimizer.py | 12 +- .../optimizers/track_best_optimizer.py | 26 +- mlos_bench/mlos_bench/os_environ.py | 11 +- mlos_bench/mlos_bench/run.py | 9 +- mlos_bench/mlos_bench/schedulers/__init__.py | 4 +- .../mlos_bench/schedulers/base_scheduler.py | 113 ++-- .../mlos_bench/schedulers/sync_scheduler.py | 4 +- mlos_bench/mlos_bench/services/__init__.py | 6 +- .../mlos_bench/services/base_fileshare.py | 43 +- .../mlos_bench/services/base_service.py | 68 ++- .../mlos_bench/services/config_persistence.py | 309 +++++++---- .../mlos_bench/services/local/__init__.py | 2 +- .../mlos_bench/services/local/local_exec.py | 59 +- .../services/local/temp_dir_context.py | 26 +- .../services/remote/azure/__init__.py | 10 +- .../services/remote/azure/azure_auth.py | 52 +- .../remote/azure/azure_deployment_services.py | 175 ++++-- .../services/remote/azure/azure_fileshare.py | 35 +- .../remote/azure/azure_network_services.py | 81 +-- .../services/remote/azure/azure_saas.py | 133 +++-- .../remote/azure/azure_vm_services.py | 269 ++++++---- .../services/remote/ssh/ssh_fileshare.py | 65 ++- .../services/remote/ssh/ssh_host_service.py | 109 ++-- .../services/remote/ssh/ssh_service.py | 185 +++++-- .../mlos_bench/services/types/__init__.py | 16 +- .../services/types/config_loader_type.py | 45 +- .../services/types/fileshare_type.py | 8 +- .../services/types/host_provisioner_type.py | 4 +- .../services/types/local_exec_type.py | 13 +- .../types/network_provisioner_type.py | 8 +- .../services/types/remote_config_type.py | 5 +- .../services/types/remote_exec_type.py | 5 +- mlos_bench/mlos_bench/storage/__init__.py | 4 +- .../storage/base_experiment_data.py | 19 
+- mlos_bench/mlos_bench/storage/base_storage.py | 131 +++-- .../mlos_bench/storage/base_trial_data.py | 26 +- .../storage/base_tunable_config_data.py | 3 +- .../base_tunable_config_trial_group_data.py | 24 +- mlos_bench/mlos_bench/storage/sql/__init__.py | 2 +- mlos_bench/mlos_bench/storage/sql/common.py | 233 +++++--- .../mlos_bench/storage/sql/experiment.py | 281 ++++++---- .../mlos_bench/storage/sql/experiment_data.py | 105 ++-- mlos_bench/mlos_bench/storage/sql/schema.py | 52 +- mlos_bench/mlos_bench/storage/sql/storage.py | 29 +- mlos_bench/mlos_bench/storage/sql/trial.py | 146 +++-- .../mlos_bench/storage/sql/trial_data.py | 80 ++- .../storage/sql/tunable_config_data.py | 14 +- .../sql/tunable_config_trial_group_data.py | 43 +- .../mlos_bench/storage/storage_factory.py | 8 +- mlos_bench/mlos_bench/storage/util.py | 24 +- mlos_bench/mlos_bench/tests/__init__.py | 40 +- .../mlos_bench/tests/config/__init__.py | 12 +- .../cli/test_load_cli_config_examples.py | 89 ++- .../mlos_bench/tests/config/conftest.py | 14 +- .../test_load_environment_config_examples.py | 72 ++- .../test_load_global_config_examples.py | 8 +- .../test_load_optimizer_config_examples.py | 8 +- .../tests/config/schemas/__init__.py | 76 ++- .../config/schemas/cli/test_cli_schemas.py | 13 +- .../environments/test_environment_schemas.py | 42 +- .../schemas/globals/test_globals_schemas.py | 9 +- .../optimizers/test_optimizer_schemas.py | 89 ++- .../schedulers/test_scheduler_schemas.py | 42 +- .../schemas/services/test_services_schemas.py | 43 +- .../schemas/storage/test_storage_schemas.py | 52 +- .../test_tunable_params_schemas.py | 9 +- .../test_tunable_values_schemas.py | 9 +- .../test_load_service_config_examples.py | 14 +- .../test_load_storage_config_examples.py | 8 +- mlos_bench/mlos_bench/tests/conftest.py | 24 +- .../mlos_bench/tests/dict_templater_test.py | 4 +- .../mlos_bench/tests/environments/__init__.py | 14 +- .../tests/environments/base_env_test.py | 4 +- .../composite_env_service_test.py | 36 +- .../tests/environments/composite_env_test.py | 160 +++--- .../environments/include_tunables_test.py | 36 +- .../tests/environments/local/__init__.py | 20 +- .../local/composite_local_env_test.py | 23 +- .../local/local_env_stdout_test.py | 88 +-- .../local/local_env_telemetry_test.py | 149 +++--- .../environments/local/local_env_test.py | 73 +-- .../environments/local/local_env_vars_test.py | 61 ++- .../local/local_fileshare_env_test.py | 25 +- .../tests/environments/mock_env_test.py | 87 +-- .../tests/environments/remote/test_ssh_env.py | 18 +- .../tests/event_loop_context_test.py | 69 ++- .../tests/launcher_in_process_test.py | 40 +- .../tests/launcher_parse_args_test.py | 147 ++--- .../mlos_bench/tests/launcher_run_test.py | 107 ++-- .../mlos_bench/tests/optimizers/conftest.py | 40 +- .../optimizers/grid_search_optimizer_test.py | 144 +++-- .../tests/optimizers/llamatune_opt_test.py | 9 +- .../tests/optimizers/mlos_core_opt_df_test.py | 68 +-- .../optimizers/mlos_core_opt_smac_test.py | 96 ++-- .../tests/optimizers/mock_opt_test.py | 71 ++- .../optimizers/opt_bulk_register_test.py | 101 ++-- .../optimizers/toy_optimization_loop_test.py | 30 +- .../mlos_bench/tests/services/__init__.py | 8 +- .../tests/services/config_persistence_test.py | 50 +- .../tests/services/local/__init__.py | 2 +- .../services/local/local_exec_python_test.py | 15 +- .../tests/services/local/local_exec_test.py | 126 +++-- .../tests/services/local/mock/__init__.py | 2 +- .../local/mock/mock_local_exec_service.py | 26 +- 
.../mlos_bench/tests/services/mock_service.py | 23 +- .../tests/services/remote/__init__.py | 6 +- .../remote/azure/azure_fileshare_test.py | 166 ++++-- .../azure/azure_network_services_test.py | 99 ++-- .../remote/azure/azure_vm_services_test.py | 231 +++++--- .../tests/services/remote/azure/conftest.py | 108 ++-- .../services/remote/mock/mock_auth_service.py | 26 +- .../remote/mock/mock_fileshare_service.py | 25 +- .../remote/mock/mock_network_service.py | 35 +- .../remote/mock/mock_remote_exec_service.py | 26 +- .../services/remote/mock/mock_vm_service.py | 55 +- .../tests/services/remote/ssh/__init__.py | 18 +- .../tests/services/remote/ssh/fixtures.py | 67 ++- .../services/remote/ssh/test_ssh_fileshare.py | 48 +- .../remote/ssh/test_ssh_host_service.py | 102 ++-- .../services/remote/ssh/test_ssh_service.py | 65 ++- .../test_service_method_registering.py | 10 +- .../mlos_bench/tests/storage/conftest.py | 8 +- .../mlos_bench/tests/storage/exp_data_test.py | 82 ++- .../mlos_bench/tests/storage/exp_load_test.py | 80 +-- .../mlos_bench/tests/storage/sql/fixtures.py | 97 ++-- .../tests/storage/trial_config_test.py | 19 +- .../tests/storage/trial_schedule_test.py | 36 +- .../tests/storage/trial_telemetry_test.py | 49 +- .../tests/storage/tunable_config_data_test.py | 26 +- .../tunable_config_trial_group_data_test.py | 68 ++- .../mlos_bench/tests/test_with_alt_tz.py | 8 +- .../tests/tunable_groups_fixtures.py | 38 +- .../mlos_bench/tests/tunables/conftest.py | 47 +- .../tunables/test_tunable_categoricals.py | 2 +- .../tunables/test_tunables_size_props.py | 23 +- .../tests/tunables/tunable_comparison_test.py | 15 +- .../tests/tunables/tunable_definition_test.py | 108 ++-- .../tunables/tunable_distributions_test.py | 68 ++- .../tunables/tunable_group_indexing_test.py | 12 +- .../tunables/tunable_group_subgroup_test.py | 2 +- .../tunable_to_configspace_distr_test.py | 54 +- .../tunables/tunable_to_configspace_test.py | 59 +- .../tests/tunables/tunables_assign_test.py | 26 +- .../tests/tunables/tunables_str_test.py | 76 +-- mlos_bench/mlos_bench/tunables/__init__.py | 6 +- .../mlos_bench/tunables/covariant_group.py | 18 +- mlos_bench/mlos_bench/tunables/tunable.py | 124 +++-- .../mlos_bench/tunables/tunable_groups.py | 76 ++- mlos_bench/mlos_bench/util.py | 59 +- mlos_bench/mlos_bench/version.py | 2 +- mlos_bench/setup.py | 95 ++-- mlos_core/mlos_core/optimizers/__init__.py | 32 +- .../bayesian_optimizers/__init__.py | 4 +- .../bayesian_optimizers/bayesian_optimizer.py | 14 +- .../bayesian_optimizers/smac_optimizer.py | 179 +++++-- .../mlos_core/optimizers/flaml_optimizer.py | 78 ++- mlos_core/mlos_core/optimizers/optimizer.py | 176 ++++-- .../mlos_core/optimizers/random_optimizer.py | 47 +- .../mlos_core/spaces/adapters/__init__.py | 19 +- .../mlos_core/spaces/adapters/adapter.py | 10 +- .../mlos_core/spaces/adapters/llamatune.py | 209 +++++--- .../mlos_core/spaces/converters/flaml.py | 26 +- mlos_core/mlos_core/tests/__init__.py | 23 +- .../optimizers/bayesian_optimizers_test.py | 23 +- .../mlos_core/tests/optimizers/conftest.py | 10 +- .../tests/optimizers/one_hot_test.py | 96 ++-- .../optimizers/optimizer_multiobj_test.py | 80 +-- .../tests/optimizers/optimizer_test.py | 241 +++++---- .../spaces/adapters/identity_adapter_test.py | 25 +- .../tests/spaces/adapters/llamatune_test.py | 505 ++++++++++++------ .../adapters/space_adapter_factory_test.py | 60 ++- .../mlos_core/tests/spaces/spaces_test.py | 77 ++- mlos_core/mlos_core/util.py | 13 +- mlos_core/mlos_core/version.py | 2 +- 
mlos_core/setup.py | 71 +-- mlos_viz/mlos_viz/__init__.py | 23 +- mlos_viz/mlos_viz/base.py | 234 +++++--- mlos_viz/mlos_viz/dabl.py | 68 ++- mlos_viz/mlos_viz/tests/test_mlos_viz.py | 4 +- mlos_viz/mlos_viz/util.py | 23 +- mlos_viz/mlos_viz/version.py | 2 +- mlos_viz/setup.py | 53 +- 213 files changed, 7913 insertions(+), 4389 deletions(-) diff --git a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py index c32dea9bf6..75c72e6207 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py @@ -20,7 +20,7 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: Flatten every dict in the hierarchy and rename the keys with the dict path. """ if isinstance(data, dict): - for (key, val) in data.items(): + for key, val in data.items(): yield from _flat_dict(val, f"{path}.{key}") else: yield (path, data) @@ -30,13 +30,15 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: """ Convert FIO read data from JSON to tall CSV. """ - with open(input_file, mode='r', encoding='utf-8') as fh_input: + with open(input_file, mode="r", encoding="utf-8") as fh_input: json_data = json.load(fh_input) - data = list(itertools.chain( - _flat_dict(json_data["jobs"][0], prefix), - _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util") - )) + data = list( + itertools.chain( + _flat_dict(json_data["jobs"][0], prefix), + _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"), + ) + ) tall_df = pandas.DataFrame(data, columns=["metric", "value"]) tall_df.to_csv(output_file, index=False) @@ -49,12 +51,16 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.") parser.add_argument( - "input", help="FIO benchmark results in JSON format (downloaded from a remote VM).") + "input", + help="FIO benchmark results in JSON format (downloaded from a remote VM).", + ) parser.add_argument( - "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).") + "output", + help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).", + ) parser.add_argument( - "--prefix", default="fio", - help="Prefix of the metric IDs (default 'fio')") + "--prefix", default="fio", help="Prefix of the metric IDs (default 'fio')" + ) args = parser.parse_args() _main(args.input, args.output, args.prefix) diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py index 949b9f9d91..d41f20d2a9 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py @@ -14,17 +14,19 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ - open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in json.load(fh_tunables).items(): - line = f'{key} {val}' + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( + fname_output, "wt", encoding="utf-8", newline="" + ) as fh_config: + for key, val in json.load(fh_tunables).items(): + line = f"{key} {val}" 
fh_config.write(line + "\n") print(line) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate Redis config from tunable parameters JSON.") + description="generate Redis config from tunable parameters JSON." + ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output Redis config file.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py index e33c717953..eb0b904c5d 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py @@ -21,18 +21,21 @@ def _main(input_file: str, output_file: str) -> None: # Format the results from wide to long # The target is columns of metric and value to act as key-value pairs. df_long = ( - df_wide - .melt(id_vars=["test"]) + df_wide.melt(id_vars=["test"]) .assign(metric=lambda df: df["test"] + "_" + df["variable"]) .drop(columns=["test", "variable"]) .loc[:, ["metric", "value"]] ) # Add a default `score` metric to the end of the dataframe. - df_long = pd.concat([ - df_long, - pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}) - ]) + df_long = pd.concat( + [ + df_long, + pd.DataFrame( + {"metric": ["score"], "value": [df_long.value[df_long.index.max()]]} + ), + ] + ) df_long.to_csv(output_file, index=False) print(f"Converted: {input_file} -> {output_file}") @@ -40,9 +43,16 @@ def _main(input_file: str, output_file: str) -> None: if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.") - parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).") - parser.add_argument("output", help="Converted Redis benchmark data" + - " (to be consumed by OS Autotune framework).") + parser = argparse.ArgumentParser( + description="Post-process Redis benchmark results." + ) + parser.add_argument( + "input", help="Redis benchmark results (downloaded from a remote VM)." 
+ ) + parser.add_argument( + "output", + help="Converted Redis benchmark data" + + " (to be consumed by OS Autotune framework).", + ) args = parser.parse_args() _main(args.input, args.output) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py index 41bd162459..649d537558 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py @@ -14,8 +14,11 @@ JSON_CONFIG_FILE = "config-boot-time.json" NEW_CFG = "zz-mlos-boot-params.cfg" -with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \ - open(NEW_CFG, 'w', encoding='UTF-8') as fh_config: +with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open( + NEW_CFG, "w", encoding="UTF-8" +) as fh_config: for key, val in json.load(fh_json).items(): - fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$' - f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n') + fh_config.write( + 'GRUB_CMDLINE_LINUX_DEFAULT="$' + f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n' + ) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py index de344d61fb..9f130e5c0e 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py @@ -14,9 +14,10 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ - open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in json.load(fh_tunables).items(): + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( + fname_output, "wt", encoding="utf-8", newline="" + ) as fh_config: + for key, val in json.load(fh_tunables).items(): line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"' fh_config.write(line + "\n") print(line) @@ -24,7 +25,8 @@ def _main(fname_input: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Generate GRUB config from tunable parameters JSON.") + description="Generate GRUB config from tunable parameters JSON." 
+ ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output shell script to configure GRUB.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py index 85a49a1817..e632495061 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py @@ -22,7 +22,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: tunables_meta = json.load(fh_meta) with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in tunables_data.items(): + for key, val in tunables_data.items(): meta = tunables_meta.get(key, {}) name_prefix = meta.get("name_prefix", "") line = f'echo "{val}" > {name_prefix}{key}' @@ -33,7 +33,8 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate a script to update kernel parameters from tunables JSON.") + description="generate a script to update kernel parameters from tunables JSON." + ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("meta", help="JSON file with tunable parameters metadata.") diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py index fa3b63e2e6..672a215aad 100644 --- a/mlos_bench/mlos_bench/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/config/schemas/__init__.py @@ -9,6 +9,6 @@ from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema __all__ = [ - 'ConfigSchema', - 'CONFIG_SCHEMA_DIR', + "ConfigSchema", + "CONFIG_SCHEMA_DIR", ] diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index 82cbcacce2..181f96e5d6 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -27,9 +27,14 @@ # It is used in `ConfigSchema.validate()` method below. # NOTE: this may cause pytest to fail if it's expecting exceptions # to be raised for invalid configs. 
-_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION' -_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower() - in {'true', 'y', 'yes', 'on', '1'}) +_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION" +_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in { + "true", + "y", + "yes", + "on", + "1", +} # Note: we separate out the SchemaStore from a class method on ConfigSchema @@ -80,10 +85,12 @@ def _load_registry(cls) -> None: """Also store them in a Registry object for referencing by recent versions of jsonschema.""" if not cls._SCHEMA_STORE: cls._load_schemas() - cls._REGISTRY = Registry().with_resources([ - (url, Resource.from_contents(schema, default_specification=DRAFT202012)) - for url, schema in cls._SCHEMA_STORE.items() - ]) + cls._REGISTRY = Registry().with_resources( + [ + (url, Resource.from_contents(schema, default_specification=DRAFT202012)) + for url, schema in cls._SCHEMA_STORE.items() + ] + ) @property def registry(self) -> Registry: diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index 4ccef7817b..26c573b4c0 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -13,7 +13,7 @@ from mlos_bench.os_environ import environ -class DictTemplater: # pylint: disable=too-few-public-methods +class DictTemplater: # pylint: disable=too-few-public-methods """ Simple class to help with nested dictionary $var templating. """ @@ -32,9 +32,12 @@ def __init__(self, source_dict: Dict[str, Any]): # The source/target dictionary to expand. self._dict: Dict[str, Any] = {} - def expand_vars(self, *, - extra_source_dict: Optional[Dict[str, Any]] = None, - use_os_env: bool = False) -> Dict[str, Any]: + def expand_vars( + self, + *, + extra_source_dict: Optional[Dict[str, Any]] = None, + use_os_env: bool = False, + ) -> Dict[str, Any]: """ Expand the template variables in the destination dictionary. @@ -55,7 +58,9 @@ def expand_vars(self, *, assert isinstance(self._dict, dict) return self._dict - def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any: + def _expand_vars( + self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool + ) -> Any: """ Recursively expand $var strings in the currently operating dictionary. """ @@ -71,10 +76,12 @@ def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], elif isinstance(value, dict): # Note: we use a loop instead of dict comprehension in order to # allow secondary expansion of subsequent values immediately. 
- for (key, val) in value.items(): + for key, val in value.items(): value[key] = self._expand_vars(val, extra_source_dict, use_os_env) elif isinstance(value, list): - value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value] + value = [ + self._expand_vars(val, extra_source_dict, use_os_env) for val in value + ] elif isinstance(value, (int, float, bool)) or value is None: return value else: diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py index a1ccadae5f..629e7d9c5f 100644 --- a/mlos_bench/mlos_bench/environments/__init__.py +++ b/mlos_bench/mlos_bench/environments/__init__.py @@ -15,12 +15,11 @@ from mlos_bench.environments.status import Status __all__ = [ - 'Status', - - 'Environment', - 'MockEnv', - 'RemoteEnv', - 'LocalEnv', - 'LocalFileShareEnv', - 'CompositeEnv', + "Status", + "Environment", + "MockEnv", + "RemoteEnv", + "LocalEnv", + "LocalFileShareEnv", + "CompositeEnv", ] diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index 61fbd69f50..d358f903be 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -48,15 +48,16 @@ class Environment(metaclass=abc.ABCMeta): """ @classmethod - def new(cls, - *, - env_name: str, - class_name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ) -> "Environment": + def new( + cls, + *, + env_name: str, + class_name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -94,16 +95,18 @@ def new(cls, config=config, global_config=global_config, tunables=tunables, - service=service + service=service, ) - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment with a given config. @@ -134,34 +137,41 @@ def __init__(self, self._const_args: Dict[str, TunableValue] = config.get("const_args", {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Environment: '%s' Service: %s", name, - self._service.pprint() if self._service else None) + _LOG.debug( + "Environment: '%s' Service: %s", + name, + self._service.pprint() if self._service else None, + ) if tunables is None: - _LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name) + _LOG.warning( + "No tunables provided for %s. 
Tunable inheritance across composite environments may be broken.", + name, + ) tunables = TunableGroups() groups = self._expand_groups( config.get("tunable_params", []), - (global_config or {}).get("tunable_params_map", {})) + (global_config or {}).get("tunable_params_map", {}), + ) _LOG.debug("Tunable groups for: '%s' :: %s", name, groups) self._tunable_params = tunables.subgroup(groups) # If a parameter comes from the tunables, do not require it in the const_args or globals - req_args = ( - set(config.get("required_args", [])) - - set(self._tunable_params.get_param_values().keys()) + req_args = set(config.get("required_args", [])) - set( + self._tunable_params.get_param_values().keys() + ) + merge_parameters( + dest=self._const_args, source=global_config, required_keys=req_args ) - merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args) self._const_args = self._expand_vars(self._const_args, global_config or {}) self._params = self._combine_tunables(self._tunable_params) _LOG.debug("Parameters for '%s' :: %s", name, self._params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Config for: '%s'\n%s", - name, json.dumps(self.config, indent=2)) + _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2)) def _validate_json_config(self, config: dict, name: str) -> None: """ @@ -179,8 +189,9 @@ def _validate_json_config(self, config: dict, name: str) -> None: ConfigSchema.ENVIRONMENT.validate(json_config) @staticmethod - def _expand_groups(groups: Iterable[str], - groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]: + def _expand_groups( + groups: Iterable[str], groups_exp: Dict[str, Union[str, Sequence[str]]] + ) -> List[str]: """ Expand `$tunable_group` into actual names of the tunable groups. @@ -202,7 +213,9 @@ def _expand_groups(groups: Iterable[str], if grp[:1] == "$": tunable_group_name = grp[1:] if tunable_group_name not in groups_exp: - raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}") + raise KeyError( + f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}" + ) add_groups = groups_exp[tunable_group_name] res += [add_groups] if isinstance(add_groups, str) else add_groups else: @@ -210,7 +223,9 @@ def _expand_groups(groups: Iterable[str], return res @staticmethod - def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict: + def _expand_vars( + params: Dict[str, TunableValue], global_config: Dict[str, TunableValue] + ) -> dict: """ Expand `$var` into actual values of the variables. """ @@ -221,7 +236,7 @@ def _config_loader_service(self) -> "SupportsConfigLoading": assert self._service is not None return self._service.config_loader_service - def __enter__(self) -> 'Environment': + def __enter__(self) -> "Environment": """ Enter the environment's benchmarking context. """ @@ -232,9 +247,12 @@ def __enter__(self) -> 'Environment': self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the benchmarking environment. 
""" @@ -243,14 +261,20 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], _LOG.debug("Environment END :: %s", self) else: assert ex_type and ex_val - _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb)) + _LOG.warning( + "Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb) + ) assert self._in_context if self._service_context: try: self._service_context.__exit__(ex_type, ex_val, ex_tb) # pylint: disable=broad-exception-caught except Exception as ex: - _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex) + _LOG.error( + "Exception while exiting Service context '%s': %s", + self._service, + ex, + ) ex_throw = ex finally: self._service_context = None @@ -304,7 +328,8 @@ def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ return tunables.get_param_values( group_names=list(self._tunable_params.get_covariant_group_names()), - into_params=self._const_args.copy()) + into_params=self._const_args.copy(), + ) @property def tunable_params(self) -> TunableGroups: @@ -331,7 +356,9 @@ def parameters(self) -> Dict[str, TunableValue]: """ return self._params - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Set up a new benchmark environment, if necessary. This method must be idempotent, i.e., calling it several times in a row should be @@ -364,10 +391,15 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - # (Derived classes still have to check `self._tunable_params.is_updated()`). is_updated = self._tunable_params.is_updated() if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, { - name: self._tunable_params.is_updated([name]) - for name in self._tunable_params.get_covariant_group_names() - }) + _LOG.debug( + "Env '%s': Tunable groups reset = %s :: %s", + self, + is_updated, + { + name: self._tunable_params.is_updated([name]) + for name in self._tunable_params.get_covariant_group_names() + }, + ) else: _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated) diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index a71b8ab9be..4b5e2755cf 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -27,13 +27,15 @@ class CompositeEnv(Environment): Composite benchmark environment. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment with a given config. @@ -53,8 +55,13 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) # By default, the Environment includes only the tunables explicitly specified # in the "tunable_params" section of the config. 
`CompositeEnv`, however, must @@ -70,20 +77,28 @@ def __init__(self, # each CompositeEnv gets a copy of the original global config and adjusts it with # the `const_args` specific to it. global_config = (global_config or {}).copy() - for (key, val) in self._const_args.items(): + for key, val in self._const_args.items(): global_config.setdefault(key, val) for child_config_file in config.get("include_children", []): for env in self._config_loader_service.load_environment_list( - child_config_file, tunables, global_config, self._const_args, self._service): + child_config_file, + tunables, + global_config, + self._const_args, + self._service, + ): self._add_child(env, tunables) for child_config in config.get("children", []): env = self._config_loader_service.build_environment( - child_config, tunables, global_config, self._const_args, self._service) + child_config, tunables, global_config, self._const_args, self._service + ) self._add_child(env, tunables) - _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params) + _LOG.debug( + "Build composite environment '%s' END: %s", self, self._tunable_params + ) if not self._children: raise ValueError("At least one child environment must be present") @@ -92,16 +107,21 @@ def __enter__(self) -> Environment: self._child_contexts = [env.__enter__() for env in self._children] return super().__enter__() - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: ex_throw = None for env in reversed(self._children): try: env.__exit__(ex_type, ex_val, ex_tb) # pylint: disable=broad-exception-caught except Exception as ex: - _LOG.error("Exception while exiting child environment '%s': %s", env, ex) + _LOG.error( + "Exception while exiting child environment '%s': %s", env, ex + ) ex_throw = ex self._child_contexts = [] super().__exit__(ex_type, ex_val, ex_tb) @@ -132,8 +152,11 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: pretty : str Pretty-printed environment configuration. """ - return super().pprint(indent, level) + '\n' + '\n'.join( - child.pprint(indent, level + 1) for child in self._children) + return ( + super().pprint(indent, level) + + "\n" + + "\n".join(child.pprint(indent, level + 1) for child in self._children) + ) def _add_child(self, env: Environment, tunables: TunableGroups) -> None: """ @@ -145,7 +168,9 @@ def _add_child(self, env: Environment, tunables: TunableGroups) -> None: self._tunable_params.merge(env.tunable_params) tunables.merge(env.tunable_params) - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Set up the children environments. @@ -165,7 +190,9 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - """ assert self._in_context self._is_ready = super().setup(tunables, global_config) and all( - env_context.setup(tunables, global_config) for env_context in self._child_contexts) + env_context.setup(tunables, global_config) + for env_context in self._child_contexts + ) return self._is_ready def teardown(self) -> None: @@ -202,7 +229,9 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: for env_context in self._child_contexts: _LOG.debug("Child env. 
run: %s", env_context) (status, timestamp, metrics) = env_context.run() - _LOG.debug("Child env. run results: %s :: %s %s", env_context, status, metrics) + _LOG.debug( + "Child env. run results: %s :: %s %s", env_context, status, metrics + ) if not status.is_good(): _LOG.info("Run failed: %s :: %s", self, status) return (status, timestamp, None) diff --git a/mlos_bench/mlos_bench/environments/local/__init__.py b/mlos_bench/mlos_bench/environments/local/__init__.py index 0cdd8349b4..a99eefea19 100644 --- a/mlos_bench/mlos_bench/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/environments/local/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv __all__ = [ - 'LocalEnv', - 'LocalFileShareEnv', + "LocalEnv", + "LocalFileShareEnv", ] diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index da20f5c961..a78898d90b 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -36,13 +36,15 @@ class LocalEnv(ScriptEnv): Scheduler-side Environment that runs scripts locally. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for local execution. @@ -65,11 +67,17 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ - "LocalEnv requires a service that supports local execution" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsLocalExec + ), "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service self._temp_dir: Optional[str] = None @@ -79,17 +87,24 @@ def __init__(self, self._dump_meta_file: Optional[str] = self.config.get("dump_meta_file") self._read_results_file: Optional[str] = self.config.get("read_results_file") - self._read_telemetry_file: Optional[str] = self.config.get("read_telemetry_file") + self._read_telemetry_file: Optional[str] = self.config.get( + "read_telemetry_file" + ) def __enter__(self) -> Environment: assert self._temp_dir is None and self._temp_dir_context is None - self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir")) + self._temp_dir_context = self._local_exec_service.temp_dir_context( + self.config.get("temp_dir") + ) self._temp_dir = self._temp_dir_context.__enter__() return super().__enter__() - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the benchmarking environment. 
""" @@ -99,7 +114,9 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], self._temp_dir_context = None return super().__exit__(ex_type, ex_val, ex_tb) - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Check if the environment is ready and set up the application and benchmarks, if necessary. @@ -137,13 +154,19 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - fname = path_join(self._temp_dir, self._dump_meta_file) _LOG.debug("Dump tunables metadata to file: %s", fname) with open(fname, "wt", encoding="utf-8") as fh_meta: - json.dump({ - tunable.name: tunable.meta - for (tunable, _group) in self._tunable_params if tunable.meta - }, fh_meta) + json.dump( + { + tunable.name: tunable.meta + for (tunable, _group) in self._tunable_params + if tunable.meta + }, + fh_meta, + ) if self._script_setup: - (return_code, _output) = self._local_exec(self._script_setup, self._temp_dir) + (return_code, _output) = self._local_exec( + self._script_setup, self._temp_dir + ) self._is_ready = bool(return_code == 0) else: self._is_ready = True @@ -180,18 +203,26 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: _LOG.debug("Not reading the data at: %s", self) return (Status.SUCCEEDED, timestamp, stdout_data) - data = self._normalize_columns(pandas.read_csv( - self._config_loader_service.resolve_path( - self._read_results_file, extra_paths=[self._temp_dir]), - index_col=False, - )) + data = self._normalize_columns( + pandas.read_csv( + self._config_loader_service.resolve_path( + self._read_results_file, extra_paths=[self._temp_dir] + ), + index_col=False, + ) + ) _LOG.debug("Read data:\n%s", data) if list(data.columns) == ["metric", "value"]: - _LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data)) - data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list()) + _LOG.info( + "Local results have (metric,value) header and %d rows: assume long format", + len(data), + ) + data = pandas.DataFrame( + [data.value.to_list()], columns=data.metric.to_list() + ) # Try to convert string metrics to numbers. - data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive) + data = data.apply(pandas.to_numeric, errors="coerce").fillna(data) # type: ignore[assignment] # (false positive) elif len(data) == 1: _LOG.info("Local results have 1 row: assume wide format") else: @@ -209,8 +240,8 @@ def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame: # Windows cmd interpretation of > redirect symbols can leave trailing spaces in # the final column, which leads to misnamed columns. # For now, we simply strip trailing spaces from column names to account for that. - if sys.platform == 'win32': - data.rename(str.rstrip, axis='columns', inplace=True) + if sys.platform == "win32": + data.rename(str.rstrip, axis="columns", inplace=True) return data def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: @@ -222,36 +253,45 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: assert self._temp_dir is not None try: fname = self._config_loader_service.resolve_path( - self._read_telemetry_file, extra_paths=[self._temp_dir]) + self._read_telemetry_file, extra_paths=[self._temp_dir] + ) # TODO: Use the timestamp of the CSV file as our status timestamp? 
# FIXME: We should not be assuming that the only output file type is a CSV. - data = self._normalize_columns( - pandas.read_csv(fname, index_col=False)) + data = self._normalize_columns(pandas.read_csv(fname, index_col=False)) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") expected_col_names = ["timestamp", "metric", "value"] if len(data.columns) != len(expected_col_names): - raise ValueError(f'Telemetry data must have columns {expected_col_names}') + raise ValueError( + f"Telemetry data must have columns {expected_col_names}" + ) if list(data.columns) != expected_col_names: # Assume no header - this is ok for telemetry data. - data = pandas.read_csv( - fname, index_col=False, names=expected_col_names) + data = pandas.read_csv(fname, index_col=False, names=expected_col_names) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") except FileNotFoundError as ex: - _LOG.warning("Telemetry CSV file not found: %s :: %s", self._read_telemetry_file, ex) + _LOG.warning( + "Telemetry CSV file not found: %s :: %s", self._read_telemetry_file, ex + ) return (status, timestamp, []) _LOG.debug("Read telemetry data:\n%s", data) col_dtypes: Mapping[int, Type] = {0: datetime} - return (status, timestamp, [ - (pandas.Timestamp(ts).to_pydatetime(), metric, value) - for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) - ]) + return ( + status, + timestamp, + [ + (pandas.Timestamp(ts).to_pydatetime(), metric, value) + for (ts, metric, value) in data.to_records( + index=False, column_dtypes=col_dtypes + ) + ], + ) def teardown(self) -> None: """ @@ -263,7 +303,9 @@ def teardown(self) -> None: _LOG.info("Local teardown complete: %s :: %s", self, return_code) super().teardown() - def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple[int, dict]: + def _local_exec( + self, script: Iterable[str], cwd: Optional[str] = None + ) -> Tuple[int, dict]: """ Execute a script locally in the scheduler environment. @@ -283,7 +325,10 @@ def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple env_params = self._get_env_params() _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params) (return_code, stdout, stderr) = self._local_exec_service.local_exec( - script, env=env_params, cwd=cwd) + script, env=env_params, cwd=cwd + ) if return_code != 0: - _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr) + _LOG.warning( + "ERROR: Local script returns code %d stderr:\n%s", return_code, stderr + ) return (return_code, {"stdout": stdout, "stderr": stderr}) diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index 174afd387c..636c7cb9a5 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -29,13 +29,15 @@ class LocalFileShareEnv(LocalEnv): and uploads/downloads data to the shared file storage. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new application environment with a given config. 
@@ -59,14 +61,22 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) - assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ - "LocalEnv requires a service that supports local execution" + assert self._service is not None and isinstance( + self._service, SupportsLocalExec + ), "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service - assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \ - "LocalEnv requires a service that supports file upload/download operations" + assert self._service is not None and isinstance( + self._service, SupportsFileShareOps + ), "LocalEnv requires a service that supports file upload/download operations" self._file_share_service: SupportsFileShareOps = self._service self._upload = self._template_from_to("upload") @@ -78,13 +88,14 @@ def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]: of string.Template objects so that we can plug in self._params into it later. """ return [ - (Template(d['from']), Template(d['to'])) + (Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, []) ] @staticmethod - def _expand(from_to: Iterable[Tuple[Template, Template]], - params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]: + def _expand( + from_to: Iterable[Tuple[Template, Template]], params: Mapping[str, TunableValue] + ) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. Return a generator of (str, str) pairs of paths. @@ -94,7 +105,9 @@ def _expand(from_to: Iterable[Tuple[Template, Template]], for (path_from, path_to) in from_to ) - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Run setup scripts locally and upload the scripts and data to the shared storage. 
@@ -119,9 +132,14 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for (path_from, path_to) in self._expand(self._upload, params): - self._file_share_service.upload(self._params, self._config_loader_service.resolve_path( - path_from, extra_paths=[self._temp_dir]), path_to) + for path_from, path_to in self._expand(self._upload, params): + self._file_share_service.upload( + self._params, + self._config_loader_service.resolve_path( + path_from, extra_paths=[self._temp_dir] + ), + path_to, + ) return self._is_ready def _download_files(self, ignore_missing: bool = False) -> None: @@ -137,11 +155,15 @@ def _download_files(self, ignore_missing: bool = False) -> None: assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for (path_from, path_to) in self._expand(self._download, params): + for path_from, path_to in self._expand(self._download, params): try: - self._file_share_service.download(self._params, - path_from, self._config_loader_service.resolve_path( - path_to, extra_paths=[self._temp_dir])) + self._file_share_service.download( + self._params, + path_from, + self._config_loader_service.resolve_path( + path_to, extra_paths=[self._temp_dir] + ), + ) except FileNotFoundError as ex: _LOG.warning("Cannot download: %s", path_from) if not ignore_missing: diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index cc47b95500..c9d6ac7ed3 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -29,13 +29,15 @@ class MockEnv(Environment): _NOISE_VAR = 0.2 """Variance of the Gaussian noise added to the benchmark value.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment that produces mock benchmark data. @@ -55,8 +57,13 @@ def __init__(self, service: Service An optional service object. Not used by this class. """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) seed = int(self.config.get("mock_env_seed", -1)) self._random = random.Random(seed or None) if seed >= 0 else None self._range = self.config.get("mock_env_range") @@ -81,9 +88,14 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: return result # Simple convex function of all tunable parameters. - score = numpy.mean(numpy.square([ - self._normalized(tunable) for (tunable, _group) in self._tunable_params - ])) + score = numpy.mean( + numpy.square( + [ + self._normalized(tunable) + for (tunable, _group) in self._tunable_params + ] + ) + ) # Add noise and shift the benchmark value from [0, 1] to a given range. 
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0 @@ -91,7 +103,11 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: if self._range: score = self._range[0] + score * (self._range[1] - self._range[0]) - return (Status.SUCCEEDED, timestamp, {metric: score for metric in self._metrics}) + return ( + Status.SUCCEEDED, + timestamp, + {metric: score for metric in self._metrics}, + ) @staticmethod def _normalized(tunable: Tunable) -> float: @@ -101,11 +117,13 @@ def _normalized(tunable: Tunable) -> float: """ val = None if tunable.is_categorical: - val = (tunable.categories.index(tunable.category) / - float(len(tunable.categories) - 1)) + val = tunable.categories.index(tunable.category) / float( + len(tunable.categories) - 1 + ) elif tunable.is_numerical: - val = ((tunable.numerical_value - tunable.range[0]) / - float(tunable.range[1] - tunable.range[0])) + val = (tunable.numerical_value - tunable.range[0]) / float( + tunable.range[1] - tunable.range[0] + ) else: raise ValueError("Invalid parameter type: " + tunable.type) # Explicitly clip the value in case of numerical errors. diff --git a/mlos_bench/mlos_bench/environments/remote/__init__.py b/mlos_bench/mlos_bench/environments/remote/__init__.py index f07575ac86..be18bff2fe 100644 --- a/mlos_bench/mlos_bench/environments/remote/__init__.py +++ b/mlos_bench/mlos_bench/environments/remote/__init__.py @@ -14,10 +14,10 @@ from mlos_bench.environments.remote.vm_env import VMEnv __all__ = [ - 'HostEnv', - 'NetworkEnv', - 'OSEnv', - 'RemoteEnv', - 'SaaSEnv', - 'VMEnv', + "HostEnv", + "NetworkEnv", + "OSEnv", + "RemoteEnv", + "SaaSEnv", + "VMEnv", ] diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py index 05896c9e60..e754fce417 100644 --- a/mlos_bench/mlos_bench/environments/remote/host_env.py +++ b/mlos_bench/mlos_bench/environments/remote/host_env.py @@ -22,13 +22,15 @@ class HostEnv(Environment): Remote host environment. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for host operations. @@ -49,13 +51,22 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM/host, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) - assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \ - "HostEnv requires a service that supports host provisioning operations" + assert self._service is not None and isinstance( + self._service, SupportsHostProvisioning + ), "HostEnv requires a service that supports host provisioning operations" self._host_service: SupportsHostProvisioning = self._service - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Check if host is ready. (Re)provision and start it, if necessary. 
@@ -93,7 +104,9 @@ def teardown(self) -> None: _LOG.info("Host tear down: %s", self) (status, params) = self._host_service.deprovision_host(self._params) if status.is_pending(): - (status, _) = self._host_service.wait_host_deployment(params, is_setup=False) + (status, _) = self._host_service.wait_host_deployment( + params, is_setup=False + ) super().teardown() _LOG.debug("Final status of Host deprovisioning: %s :: %s", self, status) diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index 552f1729d9..ba06e7ad5c 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -27,13 +27,15 @@ class NetworkEnv(Environment): but no real tuning is expected for it ... yet. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for network operations. @@ -54,17 +56,26 @@ def __init__(self, An optional service object (e.g., providing methods to deploy a network, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) # Virtual networks can be used for more than one experiment, so by default # we don't attempt to deprovision them. self._deprovision_on_teardown = config.get("deprovision_on_teardown", False) - assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \ - "NetworkEnv requires a service that supports network provisioning" + assert self._service is not None and isinstance( + self._service, SupportsNetworkProvisioning + ), "NetworkEnv requires a service that supports network provisioning" self._network_service: SupportsNetworkProvisioning = self._service - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Check if network is ready. Provision, if necessary. 
@@ -91,7 +102,9 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - (status, params) = self._network_service.provision_network(self._params) if status.is_pending(): - (status, _) = self._network_service.wait_network_deployment(params, is_setup=True) + (status, _) = self._network_service.wait_network_deployment( + params, is_setup=True + ) self._is_ready = status.is_succeeded() return self._is_ready @@ -105,9 +118,13 @@ def teardown(self) -> None: return # Else _LOG.info("Network tear down: %s", self) - (status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True) + (status, params) = self._network_service.deprovision_network( + self._params, ignore_errors=True + ) if status.is_pending(): - (status, _) = self._network_service.wait_network_deployment(params, is_setup=False) + (status, _) = self._network_service.wait_network_deployment( + params, is_setup=False + ) super().teardown() _LOG.debug("Final status of Network deprovisioning: %s :: %s", self, status) diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py index ef733c77c2..398c3b65db 100644 --- a/mlos_bench/mlos_bench/environments/remote/os_env.py +++ b/mlos_bench/mlos_bench/environments/remote/os_env.py @@ -24,13 +24,15 @@ class OSEnv(Environment): OS Level Environment for a host. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for remote execution. @@ -53,17 +55,27 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance(self._service, SupportsOSOps), \ - "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance( + self._service, SupportsOSOps + ), "RemoteEnv requires a service that supports host operations" self._os_service: SupportsOSOps = self._service - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Check if the host is up and running; boot it, if necessary. diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index cf38a57b01..112d83c4f1 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -32,13 +32,15 @@ class RemoteEnv(ScriptEnv): e.g. 
Application Environment """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for remote execution. @@ -61,21 +63,30 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a Host, VM, OS, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) self._wait_boot = self.config.get("wait_boot", False) - assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \ - "RemoteEnv requires a service that supports remote execution operations" + assert self._service is not None and isinstance( + self._service, SupportsRemoteExec + ), "RemoteEnv requires a service that supports remote execution operations" self._remote_exec_service: SupportsRemoteExec = self._service if self._wait_boot: - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Check if the environment is ready and set up the application and benchmarks on a remote host. @@ -152,7 +163,9 @@ def teardown(self) -> None: _LOG.info("Remote teardown complete: %s :: %s", self, status) super().teardown() - def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optional[dict]]: + def _remote_exec( + self, script: Iterable[str] + ) -> Tuple[Status, datetime, Optional[dict]]: """ Run a script on the remote host. @@ -170,7 +183,8 @@ def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optiona env_params = self._get_env_params() _LOG.debug("Submit script: %s with %s", self, env_params) (status, output) = self._remote_exec_service.remote_exec( - script, config=self._params, env_params=env_params) + script, config=self._params, env_params=env_params + ) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) if status in {Status.PENDING, Status.SUCCEEDED}: (status, output) = self._remote_exec_service.get_remote_exec_results(output) diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index b661bfad7e..8885bafc05 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -23,13 +23,15 @@ class SaaSEnv(Environment): Cloud-based (configurable) SaaS environment. 
""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for (configurable) cloud-based SaaS instance. @@ -50,18 +52,27 @@ def __init__(self, An optional service object (e.g., providing methods to configure the remote service). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \ - "SaaSEnv requires a service that supports remote host configuration API" + assert self._service is not None and isinstance( + self._service, SupportsRemoteConfig + ), "SaaSEnv requires a service that supports remote host configuration API" self._config_service: SupportsRemoteConfig = self._service - def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: + def setup( + self, tunables: TunableGroups, global_config: Optional[dict] = None + ) -> bool: """ Update the configuration of a remote SaaS instance. @@ -84,7 +95,8 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False (status, _) = self._config_service.configure( - self._params, self._tunable_params.get_param_values()) + self._params, self._tunable_params.get_param_values() + ) if not status.is_succeeded(): return False @@ -93,7 +105,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False # Azure Flex DB instances currently require a VM reboot after reconfiguration. - if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'): + if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"): _LOG.info("Restarting: %s", self) (status, params) = self._host_service.restart_host(self._params) if status.is_pending(): diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index 129ac21a0f..d2e4992700 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -27,13 +27,15 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta): _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]") - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for script execution. @@ -63,19 +65,29 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) self._script_setup = self.config.get("setup") self._script_run = self.config.get("run") self._script_teardown = self.config.get("teardown") self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", []) - self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {}) + self._shell_env_params_rename: Dict[str, str] = self.config.get( + "shell_env_params_rename", {} + ) results_stdout_pattern = self.config.get("results_stdout_pattern") - self._results_stdout_pattern: Optional[re.Pattern[str]] = \ - re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None + self._results_stdout_pattern: Optional[re.Pattern[str]] = ( + re.compile(results_stdout_pattern, flags=re.MULTILINE) + if results_stdout_pattern + else None + ) def _get_env_params(self, restrict: bool = True) -> Dict[str, str]: """ @@ -115,5 +127,10 @@ def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]: """ if not self._results_stdout_pattern: return {} - _LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout) - return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)} + _LOG.debug( + "Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout + ) + return { + key: try_parse_val(val) + for (key, val) in self._results_stdout_pattern.findall(stdout) + } diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index 4555ab7f50..f39f96c5ad 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -20,7 +20,7 @@ else: from typing_extensions import TypeAlias -CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name +CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name if sys.version_info >= (3, 9): FutureReturnType: TypeAlias = Future[CoroReturnType] else: @@ -66,10 +66,14 @@ def enter(self) -> None: assert self._event_loop_thread_refcnt == 0 if self._event_loop is None: if sys.platform == "win32": - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.set_event_loop_policy( + asyncio.WindowsSelectorEventLoopPolicy() + ) self._event_loop = asyncio.new_event_loop() assert not self._event_loop.is_running() - self._event_loop_thread = Thread(target=self._run_event_loop, daemon=True) + self._event_loop_thread = Thread( + target=self._run_event_loop, daemon=True + ) self._event_loop_thread.start() self._event_loop_thread_refcnt += 1 @@ -90,7 +94,9 @@ def exit(self) -> None: raise RuntimeError("Failed to stop event loop thread.") self._event_loop_thread = None - def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: + def run_coroutine( + self, coro: Coroutine[Any, Any, CoroReturnType] + ) -> FutureReturnType: """ Runs the given coroutine in the background event loop thread and returns a Future that can be used to wait for the result. 
diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index c8e48dab69..c20ef557d0 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -32,7 +32,9 @@ from mlos_bench.util import try_parse_val _LOG_LEVEL = logging.INFO -_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s' +_LOG_FORMAT = ( + "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s" +) logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT) _LOG = logging.getLogger(__name__) @@ -44,7 +46,9 @@ class Launcher: Command line launcher for mlos_bench and mlos_core. """ - def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None): + def __init__( + self, description: str, long_text: str = "", argv: Optional[List[str]] = None + ): # pylint: disable=too-many-statements _LOG.info("Launch: %s", description) epilog = """ @@ -54,8 +58,9 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st For additional details, please see the website or the README.md files in the source tree: """ - parser = argparse.ArgumentParser(description=f"{description} : {long_text}", - epilog=epilog) + parser = argparse.ArgumentParser( + description=f"{description} : {long_text}", epilog=epilog + ) (args, args_rest) = self._parse_args(parser, argv) # Bootstrap config loader: command line takes priority. @@ -96,38 +101,50 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI. # It's useful to keep it there explicitly mostly for the --help output. if args.experiment_id: - self.global_config['experiment_id'] = args.experiment_id + self.global_config["experiment_id"] = args.experiment_id # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line if args.trial_config_repeat_count: - self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count + self.global_config["trial_config_repeat_count"] = ( + args.trial_config_repeat_count + ) # Ensure that the trial_id is present since it gets used by some other # configs but is typically controlled by the run optimize loop. - self.global_config.setdefault('trial_id', 1) + self.global_config.setdefault("trial_id", 1) - self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True) + self.global_config = DictTemplater(self.global_config).expand_vars( + use_os_env=True + ) assert isinstance(self.global_config, dict) # --service cli args should override the config file values. service_files: List[str] = config.get("services", []) + (args.service or []) assert isinstance(self._parent_service, SupportsConfigLoading) - self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service) + self._parent_service = self._parent_service.load_services( + service_files, self.global_config, self._parent_service + ) env_path = args.environment or config.get("environment") if not env_path: _LOG.error("No environment config specified.") - parser.error("At least the Environment config must be specified." + - " Run `mlos_bench --help` and consult `README.md` for more info.") + parser.error( + "At least the Environment config must be specified." + + " Run `mlos_bench --help` and consult `README.md` for more info." 
+ ) self.root_env_config = self._config_loader.resolve_path(env_path) self.environment: Environment = self._config_loader.load_environment( - self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service) + self.root_env_config, + TunableGroups(), + self.global_config, + service=self._parent_service, + ) _LOG.info("Init environment: %s", self.environment) # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer self.tunables = self._init_tunable_values( args.random_init or config.get("random_init", False), config.get("random_seed") if args.random_seed is None else args.random_seed, - config.get("tunable_values", []) + (args.tunable_values or []) + config.get("tunable_values", []) + (args.tunable_values or []), ) _LOG.info("Init tunables: %s", self.tunables) @@ -137,7 +154,11 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st self.storage = self._load_storage(args.storage or config.get("storage")) _LOG.info("Init storage: %s", self.storage) - self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True)) + self.teardown: bool = ( + bool(args.teardown) + if args.teardown is not None + else bool(config.get("teardown", True)) + ) self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler")) _LOG.info("Init scheduler: %s", self.scheduler) @@ -156,87 +177,146 @@ def service(self) -> Service: return self._parent_service @staticmethod - def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]: + def _parse_args( + parser: argparse.ArgumentParser, argv: Optional[List[str]] + ) -> Tuple[argparse.Namespace, List[str]]: """ Parse the command line arguments. """ parser.add_argument( - '--config', required=False, - help='Main JSON5 configuration file. Its keys are the same as the' + - ' command line options and can be overridden by the latter.\n' + - '\n' + - ' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' + - ' for additional config examples for this and other arguments.') + "--config", + required=False, + help="Main JSON5 configuration file. Its keys are the same as the" + + " command line options and can be overridden by the latter.\n" + + "\n" + + " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ " + + " for additional config examples for this and other arguments.", + ) parser.add_argument( - '--log_file', '--log-file', required=False, - help='Path to the log file. Use stdout if omitted.') + "--log_file", + "--log-file", + required=False, + help="Path to the log file. Use stdout if omitted.", + ) parser.add_argument( - '--log_level', '--log-level', required=False, type=str, - help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' + - ' Set to DEBUG for debug, WARNING for warnings only.') + "--log_level", + "--log-level", + required=False, + type=str, + help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}." 
+ + " Set to DEBUG for debug, WARNING for warnings only.", + ) parser.add_argument( - '--config_path', '--config-path', '--config-paths', '--config_paths', - nargs="+", action='extend', required=False, - help='One or more locations of JSON config files.') + "--config_path", + "--config-path", + "--config-paths", + "--config_paths", + nargs="+", + action="extend", + required=False, + help="One or more locations of JSON config files.", + ) parser.add_argument( - '--service', '--services', - nargs='+', action='extend', required=False, - help='Path to JSON file with the configuration of the service(s) for environment(s) to use.') + "--service", + "--services", + nargs="+", + action="extend", + required=False, + help="Path to JSON file with the configuration of the service(s) for environment(s) to use.", + ) parser.add_argument( - '--environment', required=False, - help='Path to JSON file with the configuration of the benchmarking environment(s).') + "--environment", + required=False, + help="Path to JSON file with the configuration of the benchmarking environment(s).", + ) parser.add_argument( - '--optimizer', required=False, - help='Path to the optimizer configuration file. If omitted, run' + - ' a single trial with default (or specified in --tunable_values).') + "--optimizer", + required=False, + help="Path to the optimizer configuration file. If omitted, run" + + " a single trial with default (or specified in --tunable_values).", + ) parser.add_argument( - '--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int, - help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.') + "--trial_config_repeat_count", + "--trial-config-repeat-count", + required=False, + type=int, + help="Number of times to repeat each config. Default is 1 trial per config, though more may be advised.", + ) parser.add_argument( - '--scheduler', required=False, - help='Path to the scheduler configuration file. By default, use' + - ' a single worker synchronous scheduler.') + "--scheduler", + required=False, + help="Path to the scheduler configuration file. By default, use" + + " a single worker synchronous scheduler.", + ) parser.add_argument( - '--storage', required=False, - help='Path to the storage configuration file.' + - ' If omitted, use the ephemeral in-memory SQL storage.') + "--storage", + required=False, + help="Path to the storage configuration file." + + " If omitted, use the ephemeral in-memory SQL storage.", + ) parser.add_argument( - '--random_init', '--random-init', required=False, default=False, - dest='random_init', action='store_true', - help='Initialize tunables with random values. (Before applying --tunable_values).') + "--random_init", + "--random-init", + required=False, + default=False, + dest="random_init", + action="store_true", + help="Initialize tunables with random values. (Before applying --tunable_values).", + ) parser.add_argument( - '--random_seed', '--random-seed', required=False, type=int, - help='Seed to use with --random_init') + "--random_seed", + "--random-seed", + required=False, + type=int, + help="Seed to use with --random_init", + ) parser.add_argument( - '--tunable_values', '--tunable-values', nargs="+", action='extend', required=False, - help='Path to one or more JSON files that contain values of the tunable' + - ' parameters. 
This can be used for a single trial (when no --optimizer' + - ' is specified) or as default values for the first run in optimization.') + "--tunable_values", + "--tunable-values", + nargs="+", + action="extend", + required=False, + help="Path to one or more JSON files that contain values of the tunable" + + " parameters. This can be used for a single trial (when no --optimizer" + + " is specified) or as default values for the first run in optimization.", + ) parser.add_argument( - '--globals', nargs="+", action='extend', required=False, - help='Path to one or more JSON files that contain additional' + - ' [private] parameters of the benchmarking environment.') + "--globals", + nargs="+", + action="extend", + required=False, + help="Path to one or more JSON files that contain additional" + + " [private] parameters of the benchmarking environment.", + ) parser.add_argument( - '--no_teardown', '--no-teardown', required=False, default=None, - dest='teardown', action='store_false', - help='Disable teardown of the environment after the benchmark.') + "--no_teardown", + "--no-teardown", + required=False, + default=None, + dest="teardown", + action="store_false", + help="Disable teardown of the environment after the benchmark.", + ) parser.add_argument( - '--experiment_id', '--experiment-id', required=False, default=None, + "--experiment_id", + "--experiment-id", + required=False, + default=None, help=""" Experiment ID to use for the benchmark. If omitted, the value from the --cli config or --globals is used. @@ -246,7 +326,7 @@ def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> T changes are made to config files, scripts, versions, etc. This is left as a manual operation as detection of what is "incompatible" is not easily automatable across systems. - """ + """, ) # By default we use the command line arguments, but allow the caller to @@ -288,16 +368,18 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]: _LOG.debug("Parsed config: %s", config) return config - def _load_config(self, - args_globals: Iterable[str], - config_path: Iterable[str], - args_rest: Iterable[str], - global_config: Dict[str, Any]) -> Dict[str, Any]: + def _load_config( + self, + args_globals: Iterable[str], + config_path: Iterable[str], + args_rest: Iterable[str], + global_config: Dict[str, Any], + ) -> Dict[str, Any]: """ Get key/value pairs of the global configuration parameters from the specified config files (if any) and command line arguments. """ - for config_file in (args_globals or []): + for config_file in args_globals or []: conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS) assert isinstance(conf, dict) global_config.update(conf) @@ -306,8 +388,9 @@ def _load_config(self, global_config["config_path"] = config_path return global_config - def _init_tunable_values(self, random_init: bool, seed: Optional[int], - args_tunables: Optional[str]) -> TunableGroups: + def _init_tunable_values( + self, random_init: bool, seed: Optional[int], args_tunables: Optional[str] + ) -> TunableGroups: """ Initialize the tunables and load key/value pairs of the tunable values from given JSON files, if specified. 
@@ -317,13 +400,17 @@ def _init_tunable_values(self, random_init: bool, seed: Optional[int], if random_init: tunables = MockOptimizer( - tunables=tunables, service=None, - config={"start_with_defaults": False, "seed": seed}).suggest() + tunables=tunables, + service=None, + config={"start_with_defaults": False, "seed": seed}, + ).suggest() _LOG.debug("Init tunables: random = %s", tunables) if args_tunables is not None: for data_file in args_tunables: - values = self._config_loader.load_config(data_file, ConfigSchema.TUNABLE_VALUES) + values = self._config_loader.load_config( + data_file, ConfigSchema.TUNABLE_VALUES + ) assert isinstance(values, Dict) tunables.assign(values) _LOG.debug("Init tunables: load %s = %s", data_file, tunables) @@ -339,15 +426,24 @@ def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer: if args_optimizer is None: # global_config may contain additional properties, so we need to # strip those out before instantiating the basic oneshot optimizer. - config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS} + config = { + key: val + for key, val in self.global_config.items() + if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS + } return OneShotOptimizer( - self.tunables, config=config, service=self._parent_service) - class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER) + self.tunables, config=config, service=self._parent_service + ) + class_config = self._config_loader.load_config( + args_optimizer, ConfigSchema.OPTIMIZER + ) assert isinstance(class_config, Dict) - optimizer = self._config_loader.build_optimizer(tunables=self.tunables, - service=self._parent_service, - config=class_config, - global_config=self.global_config) + optimizer = self._config_loader.build_optimizer( + tunables=self.tunables, + service=self._parent_service, + config=class_config, + global_config=self.global_config, + ) return optimizer def _load_storage(self, args_storage: Optional[str]) -> Storage: @@ -359,17 +455,24 @@ def _load_storage(self, args_storage: Optional[str]) -> Storage: if args_storage is None: # pylint: disable=import-outside-toplevel from mlos_bench.storage.sql.storage import SqlStorage - return SqlStorage(service=self._parent_service, - config={ - "drivername": "sqlite", - "database": ":memory:", - "lazy_schema_create": True, - }) - class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE) + + return SqlStorage( + service=self._parent_service, + config={ + "drivername": "sqlite", + "database": ":memory:", + "lazy_schema_create": True, + }, + ) + class_config = self._config_loader.load_config( + args_storage, ConfigSchema.STORAGE + ) assert isinstance(class_config, Dict) - storage = self._config_loader.build_storage(service=self._parent_service, - config=class_config, - global_config=self.global_config) + storage = self._config_loader.build_storage( + service=self._parent_service, + config=class_config, + global_config=self.global_config, + ) return storage def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: @@ -384,6 +487,7 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: if args_scheduler is None: # pylint: disable=import-outside-toplevel from mlos_bench.schedulers.sync_scheduler import SyncScheduler + return SyncScheduler( # All config values can be overridden from global config config={ @@ -399,7 +503,9 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: 
storage=self.storage, root_env_config=self.root_env_config, ) - class_config = self._config_loader.load_config(args_scheduler, ConfigSchema.SCHEDULER) + class_config = self._config_loader.load_config( + args_scheduler, ConfigSchema.SCHEDULER + ) assert isinstance(class_config, Dict) return self._config_loader.build_scheduler( config=class_config, diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py index f10fa3c82e..a61b55d440 100644 --- a/mlos_bench/mlos_bench/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/optimizers/__init__.py @@ -12,8 +12,8 @@ from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer __all__ = [ - 'Optimizer', - 'MockOptimizer', - 'OneShotOptimizer', - 'MlosCoreOptimizer', + "Optimizer", + "MockOptimizer", + "OneShotOptimizer", + "MlosCoreOptimizer", ] diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index b9df1db1b7..b67ebbfbd9 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -26,7 +26,7 @@ _LOG = logging.getLogger(__name__) -class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes +class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ An abstract interface between the benchmarking framework and mlos_core optimizers. """ @@ -39,11 +39,13 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr "start_with_defaults", } - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): """ Create a new optimizer for the given configuration space defined by the tunables. @@ -67,25 +69,30 @@ def __init__(self, self._seed = int(config.get("seed", 42)) self._in_context = False - experiment_id = self._global_config.get('experiment_id') + experiment_id = self._global_config.get("experiment_id") self.experiment_id = str(experiment_id).strip() if experiment_id else None self._iter = 0 # If False, use the optimizer to suggest the initial configuration; # if True (default), use the already initialized values for the first iteration. 
self._start_with_defaults: bool = bool( - strtobool(str(self._config.pop('start_with_defaults', True)))) - self._max_iter = int(self._config.pop('max_suggestions', 100)) + strtobool(str(self._config.pop("start_with_defaults", True))) + ) + self._max_iter = int(self._config.pop("max_suggestions", 100)) - opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'}) + opt_targets: Dict[str, str] = self._config.pop( + "optimization_targets", {"score": "min"} + ) self._opt_targets: Dict[str, Literal[1, -1]] = {} - for (opt_target, opt_dir) in opt_targets.items(): + for opt_target, opt_dir in opt_targets.items(): if opt_dir == "min": self._opt_targets[opt_target] = 1 elif opt_dir == "max": self._opt_targets[opt_target] = -1 else: - raise ValueError(f"Invalid optimization direction: {opt_dir} for {opt_target}") + raise ValueError( + f"Invalid optimization direction: {opt_dir} for {opt_target}" + ) def _validate_json_config(self, config: dict) -> None: """ @@ -107,7 +114,7 @@ def __repr__(self) -> str: ) return f"{self.name}({opt_targets},config={self._config})" - def __enter__(self) -> 'Optimizer': + def __enter__(self) -> "Optimizer": """ Enter the optimizer's context. """ @@ -116,9 +123,12 @@ def __enter__(self) -> 'Optimizer': self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the optimizer. """ @@ -190,7 +200,9 @@ def config_space(self) -> ConfigurationSpace: The ConfigSpace representation of the tunable parameters. """ if self._config_space is None: - self._config_space = tunable_groups_to_configspace(self._tunables, self._seed) + self._config_space = tunable_groups_to_configspace( + self._tunables, self._seed + ) _LOG.debug("ConfigSpace: %s", self._config_space) return self._config_space @@ -203,7 +215,7 @@ def name(self) -> str: return self.__class__.__name__ @property - def targets(self) -> Dict[str, Literal['min', 'max']]: + def targets(self) -> Dict[str, Literal["min", "max"]]: """ A dictionary of {target: direction} of optimization targets. """ @@ -220,10 +232,12 @@ def supports_preload(self) -> bool: return True @abstractmethod - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: """ Pre-load the optimizer with the bulk data from previous experiments. @@ -241,8 +255,12 @@ def bulk_register(self, is_not_empty : bool True if there is data to register, false otherwise. 
""" - _LOG.info("Update the optimizer with: %d configs, %d scores, %d status values", - len(configs or []), len(scores or []), len(status or [])) + _LOG.info( + "Update the optimizer with: %d configs, %d scores, %d status values", + len(configs or []), + len(scores or []), + len(status or []), + ) if len(configs or []) != len(scores or []): raise ValueError("Numbers of configs and scores do not match.") if status is not None and len(configs or []) != len(status or []): @@ -271,8 +289,12 @@ def suggest(self) -> TunableGroups: return self._tunables.copy() @abstractmethod - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: """ Register the observation for the given configuration. @@ -293,15 +315,22 @@ def register(self, tunables: TunableGroups, status: Status, Benchmark scores extracted (and possibly transformed) from the dataframe that's being MINIMIZED. """ - _LOG.info("Iteration %d :: Register: %s = %s score: %s", - self._iter, tunables, status, score) + _LOG.info( + "Iteration %d :: Register: %s = %s score: %s", + self._iter, + tunables, + status, + score, + ) if status.is_succeeded() == (score is None): # XOR raise ValueError("Status and score must be consistent.") return self._get_scores(status, score) - def _get_scores(self, status: Status, - scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] - ) -> Optional[Dict[str, float]]: + def _get_scores( + self, + status: Status, + scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]], + ) -> Optional[Dict[str, float]]: """ Extract a scalar benchmark score from the dataframe. Change the sign if we are maximizing. @@ -330,7 +359,7 @@ def _get_scores(self, status: Status, assert scores is not None target_metrics: Dict[str, float] = {} - for (opt_target, opt_dir) in self._opt_targets.items(): + for opt_target, opt_dir in self._opt_targets.items(): val = scores[opt_target] assert val is not None target_metrics[opt_target] = float(val) * opt_dir @@ -345,7 +374,9 @@ def not_converged(self) -> bool: return self._iter < self._max_iter @abstractmethod - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: """ Get the best observation so far. diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index 62341c613d..6dc24c01d9 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -48,7 +48,8 @@ def _normalize_weights(weights: List[float]) -> List[float]: def _tunable_to_configspace( - tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace: + tunable: Tunable, group_name: Optional[str] = None, cost: int = 0 +) -> ConfigurationSpace: """ Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects, wrapped in a ConfigurationSpace for composability. 
@@ -71,14 +72,19 @@ def _tunable_to_configspace( meta = {"group": group_name, "cost": cost} # {"scaling": ""} if tunable.type == "categorical": - return ConfigurationSpace({ - tunable.name: CategoricalHyperparameter( - name=tunable.name, - choices=tunable.categories, - weights=_normalize_weights(tunable.weights) if tunable.weights else None, - default_value=tunable.default, - meta=meta) - }) + return ConfigurationSpace( + { + tunable.name: CategoricalHyperparameter( + name=tunable.name, + choices=tunable.categories, + weights=( + _normalize_weights(tunable.weights) if tunable.weights else None + ), + default_value=tunable.default, + meta=meta, + ) + } + ) distribution: Union[Uniform, Normal, Beta, None] = None if tunable.distribution == "uniform": @@ -86,12 +92,12 @@ def _tunable_to_configspace( elif tunable.distribution == "normal": distribution = Normal( mu=tunable.distribution_params["mu"], - sigma=tunable.distribution_params["sigma"] + sigma=tunable.distribution_params["sigma"], ) elif tunable.distribution == "beta": distribution = Beta( alpha=tunable.distribution_params["alpha"], - beta=tunable.distribution_params["beta"] + beta=tunable.distribution_params["beta"], ) elif tunable.distribution is not None: raise TypeError(f"Invalid Distribution Type: {tunable.distribution}") @@ -103,22 +109,26 @@ def _tunable_to_configspace( log=bool(tunable.is_log), q=nullable(int, tunable.quantization), distribution=distribution, - default=(int(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None), - meta=meta + default=( + int(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None + ), + meta=meta, ) elif tunable.type == "float": range_hp = Float( name=tunable.name, bounds=tunable.range, log=bool(tunable.is_log), - q=tunable.quantization, # type: ignore[arg-type] + q=tunable.quantization, # type: ignore[arg-type] distribution=distribution, # type: ignore[arg-type] - default=(float(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None), - meta=meta + default=( + float(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None + ), + meta=meta, ) else: raise TypeError(f"Invalid Parameter Type: {tunable.type}") @@ -131,36 +141,50 @@ def _tunable_to_configspace( switch_weights = [0.5, 0.5] # FLAML requires uniform weights. if tunable.weights and tunable.range_weight is not None: special_weights = _normalize_weights(tunable.weights) - switch_weights = _normalize_weights([sum(tunable.weights), tunable.range_weight]) + switch_weights = _normalize_weights( + [sum(tunable.weights), tunable.range_weight] + ) # Create three hyperparameters: one for regular values, # one for special values, and one to choose between the two. 
(special_name, type_name) = special_param_names(tunable.name) - conf_space = ConfigurationSpace({ - tunable.name: range_hp, - special_name: CategoricalHyperparameter( - name=special_name, - choices=tunable.special, - weights=special_weights, - default_value=tunable.default if tunable.default in tunable.special else None, - meta=meta - ), - type_name: CategoricalHyperparameter( - name=type_name, - choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], - weights=switch_weights, - default_value=TunableValueKind.SPECIAL, - ), - }) - conf_space.add_condition(EqualsCondition( - conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)) - conf_space.add_condition(EqualsCondition( - conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)) + conf_space = ConfigurationSpace( + { + tunable.name: range_hp, + special_name: CategoricalHyperparameter( + name=special_name, + choices=tunable.special, + weights=special_weights, + default_value=( + tunable.default if tunable.default in tunable.special else None + ), + meta=meta, + ), + type_name: CategoricalHyperparameter( + name=type_name, + choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], + weights=switch_weights, + default_value=TunableValueKind.SPECIAL, + ), + } + ) + conf_space.add_condition( + EqualsCondition( + conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL + ) + ) + conf_space.add_condition( + EqualsCondition( + conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE + ) + ) return conf_space -def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace: +def tunable_groups_to_configspace( + tunables: TunableGroups, seed: Optional[int] = None +) -> ConfigurationSpace: """ Convert TunableGroups to hyperparameters in ConfigurationSpace. @@ -178,11 +202,14 @@ def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = A new ConfigurationSpace instance that corresponds to the input TunableGroups. """ space = ConfigurationSpace(seed=seed) - for (tunable, group) in tunables: + for tunable, group in tunables: space.add_configuration_space( - prefix="", delimiter="", + prefix="", + delimiter="", configuration_space=_tunable_to_configspace( - tunable, group.name, group.get_current_cost())) + tunable, group.name, group.get_current_cost() + ), + ) return space @@ -201,7 +228,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration: A ConfigSpace Configuration. """ values: Dict[str, TunableValue] = {} - for (tunable, _group) in tunables: + for tunable, _group in tunables: if tunable.special: (special_name, type_name) = special_param_names(tunable.name) if tunable.value in tunable.special: @@ -224,7 +251,8 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]: data = data.copy() specials = [ special_param_name_strip(k) - for k in data.keys() if special_param_name_is_temp(k) + for k in data.keys() + if special_param_name_is_temp(k) ] for k in specials: (special_name, type_name) = special_param_names(k) diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 4f207f5fc9..4f5efb6aa7 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -28,11 +28,13 @@ class GridSearchOptimizer(TrackBestOptimizer): Grid search optimizer. 
""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) # Track the grid as a set of tuples of tunable values and reconstruct the @@ -51,11 +53,21 @@ def __init__(self, def _sanity_check(self) -> None: size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables]) if size == np.inf: - raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}") + raise ValueError( + f"Unquantized tunables are not supported for grid search: {self._tunables}" + ) if size > 10000: - _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables) + _LOG.warning( + "Large number %d of config points requested for grid search: %s", + size, + self._tunables, + ) if size > self._max_iter: - _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter) + _LOG.warning( + "Grid search size %d, is greater than max iterations %d", + size, + self._max_iter, + ) def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]: """ @@ -68,12 +80,14 @@ def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], Non # names instead of the order given by TunableGroups. configs = [ configspace_data_to_tunable_values(dict(config)) - for config in - generate_grid(self.config_space, { - tunable.name: int(tunable.cardinality) - for (tunable, _group) in self._tunables - if tunable.quantization or tunable.type == "int" - }) + for config in generate_grid( + self.config_space, + { + tunable.name: int(tunable.cardinality) + for (tunable, _group) in self._tunables + if tunable.quantization or tunable.type == "int" + }, + ) ] names = set(tuple(configs.keys()) for configs in configs) assert len(names) == 1 @@ -89,7 +103,10 @@ def pending_configs(self) -> Iterable[Dict[str, TunableValue]]: Iterable[Dict[str, TunableValue]] """ # See NOTEs above. - return (dict(zip(self._config_keys, config)) for config in self._pending_configs.keys()) + return ( + dict(zip(self._config_keys, config)) + for config in self._pending_configs.keys() + ) @property def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]: @@ -101,17 +118,21 @@ def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]: Iterable[Dict[str, TunableValue]] """ # See NOTEs above. 
- return (dict(zip(self._config_keys, config)) for config in self._suggested_configs) - - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + return ( + dict(zip(self._config_keys, config)) for config in self._suggested_configs + ) + + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for (params, score, trial_status) in zip(configs, scores, status): + for params, score, trial_status in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -152,20 +173,34 @@ def suggest(self) -> TunableGroups: _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) try: - config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())) + config = dict( + ConfigSpace.Configuration( + self.config_space, values=tunables.get_param_values() + ) + ) self._suggested_configs.remove(tuple(config.values())) except KeyError: - _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables) + _LOG.warning( + "Attempted to remove missing config (previously registered?) from suggested set: %s", + tunables, + ) return registered_score def not_converged(self) -> bool: if self._iter > self._max_iter: if bool(self._pending_configs): - _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s", - len(self._pending_configs), list(self._pending_configs.keys())) + _LOG.warning( + "Exceeded max iterations, but still have %d pending configs: %s", + len(self._pending_configs), + list(self._pending_configs.keys()), + ) return False return bool(self._pending_configs) diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index d7d50f1ca5..c30134d1b1 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -40,35 +40,42 @@ class MlosCoreOptimizer(Optimizer): A wrapper class for the mlos_core optimizers. 
""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) - opt_type = getattr(OptimizerType, self._config.pop( - 'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name)) + opt_type = getattr( + OptimizerType, + self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name), + ) if opt_type == OptimizerType.SMAC: - output_directory = self._config.get('output_directory') + output_directory = self._config.get("output_directory") if output_directory is not None: # If output_directory is specified, turn it into an absolute path. - self._config['output_directory'] = os.path.abspath(output_directory) + self._config["output_directory"] = os.path.abspath(output_directory) else: - _LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.") + _LOG.warning( + "SMAC optimizer output_directory was null. SMAC will use a temporary directory." + ) # Make sure max_trials >= max_iterations. - if 'max_trials' not in self._config: - self._config['max_trials'] = self._max_iter - assert int(self._config['max_trials']) >= self._max_iter, \ - f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" + if "max_trials" not in self._config: + self._config["max_trials"] = self._max_iter + assert ( + int(self._config["max_trials"]) >= self._max_iter + ), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" - if 'run_name' not in self._config and self.experiment_id: - self._config['run_name'] = self.experiment_id + if "run_name" not in self._config and self.experiment_id: + self._config["run_name"] = self.experiment_id - space_adapter_type = self._config.pop('space_adapter_type', None) - space_adapter_config = self._config.pop('space_adapter_config', {}) + space_adapter_type = self._config.pop("space_adapter_type", None) + space_adapter_config = self._config.pop("space_adapter_config", {}) if space_adapter_type is not None: space_adapter_type = getattr(SpaceAdapterType, space_adapter_type) @@ -82,9 +89,12 @@ def __init__(self, space_adapter_kwargs=space_adapter_config, ) - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: self._opt.cleanup() return super().__exit__(ex_type, ex_val, ex_tb) @@ -92,10 +102,12 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], def name(self) -> str: return f"{self.__class__.__name__}:{self._opt.__class__.__name__}" - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False @@ -103,7 +115,8 @@ def bulk_register(self, df_configs = self._to_df(configs) # Impute missing values, if necessary df_scores = self._adjust_signs_df( - pd.DataFrame([{} if score is None else score for score in scores])) + pd.DataFrame([{} if score is None 
else score for score in scores]) + ) opt_targets = list(self._opt_targets) if status is not None: @@ -118,7 +131,9 @@ def bulk_register(self, # TODO: Specify (in the config) which metrics to pass to the optimizer. # Issue: https://github.com/microsoft/MLOS/issues/745 - self._opt.register(configs=df_configs, scores=df_scores[opt_targets].astype(float)) + self._opt.register( + configs=df_configs, scores=df_scores[opt_targets].astype(float) + ) if _LOG.isEnabledFor(logging.DEBUG): (score, _) = self.get_best_observation() @@ -130,7 +145,7 @@ def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame: """ In-place adjust the signs of the scores for MINIMIZATION problem. """ - for (opt_target, opt_dir) in self._opt_targets.items(): + for opt_target, opt_dir in self._opt_targets.items(): df_scores[opt_target] *= opt_dir return df_scores @@ -152,7 +167,7 @@ def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: df_configs = pd.DataFrame(configs) tunables_names = list(self._tunables.get_param_values().keys()) missing_cols = set(tunables_names).difference(df_configs.columns) - for (tunable, _group) in self._tunables: + for tunable, _group in self._tunables: if tunable.name in missing_cols: df_configs[tunable.name] = tunable.default else: @@ -163,7 +178,9 @@ def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: if tunable.special: (special_name, type_name) = special_param_names(tunable.name) tunables_names += [special_name, type_name] - is_special = df_configs[tunable.name].apply(tunable.special.__contains__) + is_special = df_configs[tunable.name].apply( + tunable.special.__contains__ + ) df_configs[type_name] = TunableValueKind.RANGE df_configs.loc[is_special, type_name] = TunableValueKind.SPECIAL if tunable.type == "int": @@ -185,21 +202,32 @@ def suggest(self) -> TunableGroups: self._start_with_defaults = False _LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config) return tunables.assign( - configspace_data_to_tunable_values(df_config.loc[0].to_dict())) + configspace_data_to_tunable_values(df_config.loc[0].to_dict()) + ) - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: - registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: + registered_score = super().register( + tunables, status, score + ) # Sign-adjusted for MINIMIZATION if status.is_completed(): assert registered_score is not None df_config = self._to_df([tunables.get_param_values()]) _LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config) # TODO: Specify (in the config) which metrics to pass to the optimizer. 
# Issue: https://github.com/microsoft/MLOS/issues/745 - self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float)) + self._opt.register( + configs=df_config, scores=pd.DataFrame([registered_score], dtype=float) + ) return registered_score - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: (df_config, df_score, _df_context) = self._opt.get_best_observations() if len(df_config) == 0: return (None, None) diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index ada4411b58..8dd13eb182 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -24,11 +24,13 @@ class MockOptimizer(TrackBestOptimizer): Mock optimizer to test the Environment API. """ - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) rnd = random.Random(self.seed) self._random: Dict[str, Callable[[Tunable], TunableValue]] = { @@ -37,15 +39,17 @@ def __init__(self, "int": lambda tunable: rnd.randint(*tunable.range), } - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for (params, score, trial_status) in zip(configs, scores, status): + for params, score, trial_status in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -62,7 +66,7 @@ def suggest(self) -> TunableGroups: _LOG.info("Use default tunable values") self._start_with_defaults = False else: - for (tunable, _group) in tunables: + for tunable, _group in tunables: tunable.value = self._random[tunable.type](tunable) _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index 9ad1070c46..b7a14f8af2 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -24,11 +24,13 @@ class OneShotOptimizer(MockOptimizer): # TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344 - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) _LOG.info("Run a single iteration for: %s", self._tunables) self._max_iter = 1 # Always run for just one iteration. 
diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py index 32a23142e3..0fd54b2dfa 100644 --- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py @@ -24,17 +24,23 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta): Base Optimizer class that keeps track of the best score and configuration. """ - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) self._best_config: Optional[TunableGroups] = None self._best_score: Optional[Dict[str, float]] = None - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) if status.is_succeeded() and self._is_better(registered_score): self._best_score = registered_score @@ -48,7 +54,7 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: if self._best_score is None: return True assert registered_score is not None - for (opt_target, best_score) in self._best_score.items(): + for opt_target, best_score in self._best_score.items(): score = registered_score[opt_target] if score < best_score: return True @@ -56,7 +62,9 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: return False return False - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: if self._best_score is None: return (None, None) score = self._get_scores(Status.SUCCEEDED, self._best_score) diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index a7912688a1..7f26851c6b 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py @@ -22,16 +22,19 @@ from typing_extensions import TypeAlias if sys.version_info >= (3, 9): - EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object + EnvironType: TypeAlias = os._Environ[ + str + ] # pylint: disable=protected-access,disable=unsubscriptable-object else: - EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access + EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access # Handle case sensitivity differences between platforms. 
# https://stackoverflow.com/a/19023293 -if sys.platform == 'win32': +if sys.platform == "win32": import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8) + environ: EnvironType = nt.environ else: environ: EnvironType = os.environ -__all__ = ['environ'] +__all__ = ["environ"] diff --git a/mlos_bench/mlos_bench/run.py b/mlos_bench/mlos_bench/run.py index 85c8c2b0c5..3dc5cbbfd4 100755 --- a/mlos_bench/mlos_bench/run.py +++ b/mlos_bench/mlos_bench/run.py @@ -20,10 +20,13 @@ _LOG = logging.getLogger(__name__) -def _main(argv: Optional[List[str]] = None - ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: +def _main( + argv: Optional[List[str]] = None, +) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: - launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv) + launcher = Launcher( + "mlos_bench", "Systems autotuning and benchmarking tool", argv=argv + ) with launcher.scheduler as scheduler_context: scheduler_context.start() diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py index c54e3c0efc..c53d11231d 100644 --- a/mlos_bench/mlos_bench/schedulers/__init__.py +++ b/mlos_bench/mlos_bench/schedulers/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.schedulers.sync_scheduler import SyncScheduler __all__ = [ - 'Scheduler', - 'SyncScheduler', + "Scheduler", + "SyncScheduler", ] diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index 0b6733e423..c089ff5946 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -31,13 +31,16 @@ class Scheduler(metaclass=ABCMeta): Base class for the optimization loop scheduling policies. """ - def __init__(self, *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: Storage, - root_env_config: str): + def __init__( + self, + *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: Storage, + root_env_config: str, + ): """ Create a new instance of the scheduler. The constructor of this and the derived classes is called by the persistence service @@ -60,8 +63,11 @@ def __init__(self, *, Path to the root environment configuration. """ self.global_config = global_config - config = merge_parameters(dest=config.copy(), source=global_config, - required_keys=["experiment_id", "trial_id"]) + config = merge_parameters( + dest=config.copy(), + source=global_config, + required_keys=["experiment_id", "trial_id"], + ) self._experiment_id = config["experiment_id"].strip() self._trial_id = int(config["trial_id"]) @@ -69,9 +75,13 @@ def __init__(self, *, self._max_trials = int(config.get("max_trials", -1)) self._trial_count = 0 - self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1)) + self._trial_config_repeat_count = int( + config.get("trial_config_repeat_count", 1) + ) if self._trial_config_repeat_count <= 0: - raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}") + raise ValueError( + f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}" + ) self._do_teardown = bool(config.get("teardown", True)) @@ -95,7 +105,7 @@ def __repr__(self) -> str: """ return self.__class__.__name__ - def __enter__(self) -> 'Scheduler': + def __enter__(self) -> "Scheduler": """ Enter the scheduler's context. 
""" @@ -117,10 +127,12 @@ def __enter__(self) -> 'Scheduler': ).__enter__() return self - def __exit__(self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the scheduler. """ @@ -142,8 +154,12 @@ def start(self) -> None: Start the optimization loop. """ assert self.experiment is not None - _LOG.info("START: Experiment: %s Env: %s Optimizer: %s", - self.experiment, self.environment, self.optimizer) + _LOG.info( + "START: Experiment: %s Env: %s Optimizer: %s", + self.experiment, + self.environment, + self.optimizer, + ) if _LOG.isEnabledFor(logging.INFO): _LOG.info("Root Environment:\n%s", self.environment.pprint()) @@ -160,7 +176,9 @@ def teardown(self) -> None: if self._do_teardown: self.environment.teardown() - def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: + def get_best_observation( + self, + ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: """ Get the best observation from the optimizer. """ @@ -177,7 +195,9 @@ def load_config(self, config_id: int) -> TunableGroups: tunables = self.environment.tunable_params.assign(tunable_values) _LOG.info("Load config from storage: %d", config_id) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2)) + _LOG.debug( + "Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2) + ) return tunables def _schedule_new_optimizer_suggestions(self) -> bool: @@ -204,27 +224,33 @@ def schedule_trial(self, tunables: TunableGroups) -> None: Add a configuration to the queue of trials. """ for repeat_i in range(1, self._trial_config_repeat_count + 1): - self._add_trial_to_queue(tunables, config={ - # Add some additional metadata to track for the trial such as the - # optimizer config used. - # Note: these values are unfortunately mutable at the moment. - # Consider them as hints of what the config was the trial *started*. - # It is possible that the experiment configs were changed - # between resuming the experiment (since that is not currently - # prevented). - "optimizer": self.optimizer.name, - "repeat_i": repeat_i, - "is_defaults": tunables.is_defaults, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(self.optimizer.targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - } - }) + self._add_trial_to_queue( + tunables, + config={ + # Add some additional metadata to track for the trial such as the + # optimizer config used. + # Note: these values are unfortunately mutable at the moment. + # Consider them as hints of what the config was the trial *started*. + # It is possible that the experiment configs were changed + # between resuming the experiment (since that is not currently + # prevented). 
+ "optimizer": self.optimizer.name, + "repeat_i": repeat_i, + "is_defaults": tunables.is_defaults, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(self.optimizer.targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + }, + }, + ) - def _add_trial_to_queue(self, tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> None: + def _add_trial_to_queue( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: """ Add a configuration to the queue of trials. A wrapper for the `Experiment.new_trial` method. @@ -257,4 +283,9 @@ def run_trial(self, trial: Storage.Trial) -> None: """ assert self.experiment is not None self._trial_count += 1 - _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial) + _LOG.info( + "QUEUE: Execute trial # %d/%d :: %s", + self._trial_count, + self._max_trials, + trial, + ) diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index a73a493533..3e196d4d4f 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -53,7 +53,9 @@ def run_trial(self, trial: Storage.Trial) -> None: trial.update(Status.FAILED, datetime.now(UTC)) return - (status, timestamp, results) = self.environment.run() # Block and wait for the final result. + (status, timestamp, results) = ( + self.environment.run() + ) # Block and wait for the final result. _LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results) # In async mode (TODO), poll the environment for status and telemetry diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py index bcc7d02d6f..dacbb88126 100644 --- a/mlos_bench/mlos_bench/services/__init__.py +++ b/mlos_bench/mlos_bench/services/__init__.py @@ -11,7 +11,7 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - 'Service', - 'FileShareService', - 'LocalExecService', + "Service", + "FileShareService", + "LocalExecService", ] diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index f00a7a1a00..63c222ee45 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -21,10 +21,13 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta): An abstract base of all file shares. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new file share with a given config. @@ -42,12 +45,16 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, New methods to register with the service. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) @abstractmethod - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: """ Downloads contents from a remote share path to a local path. @@ -65,11 +72,18 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b if True (the default), download the entire directory tree. """ params = params or {} - _LOG.info("Download from File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", remote_path, local_path, params) + _LOG.info( + "Download from File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", + remote_path, + local_path, + params, + ) @abstractmethod - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: """ Uploads contents from a local path to remote share path. @@ -86,5 +100,10 @@ def upload(self, params: dict, local_path: str, remote_path: str, recursive: boo if True (the default), upload the entire directory tree. """ params = params or {} - _LOG.info("Upload to File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", local_path, remote_path, params) + _LOG.info( + "Upload to File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", + local_path, + remote_path, + params, + ) diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index e7c9365bf7..724fd6e8f2 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -26,11 +26,13 @@ class Service: """ @classmethod - def new(cls, - class_name: str, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None) -> "Service": + def new( + cls, + class_name: str, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + ) -> "Service": """ Factory method for a new service with a given config. @@ -57,11 +59,13 @@ def new(cls, assert issubclass(cls, Service) return instantiate_from_config(cls, class_name, config, global_config, parent) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new service with a given config. 
@@ -96,13 +100,23 @@ def __init__(self, self._config_loader_service = parent if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2)) - _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2)) - _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None) + _LOG.debug( + "Service: %s Config:\n%s", self, json.dumps(self.config, indent=2) + ) + _LOG.debug( + "Service: %s Globals:\n%s", + self, + json.dumps(global_config or {}, indent=2), + ) + _LOG.debug( + "Service: %s Parent: %s", self, parent.pprint() if parent else None + ) @staticmethod - def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None], - local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]: + def merge_methods( + ext_methods: Union[Dict[str, Callable], List[Callable], None], + local_methods: Union[Dict[str, Callable], List[Callable]], + ) -> Dict[str, Callable]: """ Merge methods from the external caller with the local ones. This function is usually called by the derived class constructor @@ -138,9 +152,12 @@ def __enter__(self) -> "Service": self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the Service mix-in context. @@ -177,9 +194,12 @@ def _enter_context(self) -> "Service": self._in_context = True return self - def _exit_context(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def _exit_context( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exits the context for this particular Service instance. @@ -265,10 +285,12 @@ def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None # Unfortunately, by creating a set, we may destroy the ability to # preserve the context enter/exit order, but hopefully it doesn't # matter. - svc_method.__self__ for _, svc_method in self._service_methods.items() + svc_method.__self__ + for _, svc_method in self._service_methods.items() # Note: some methods are actually stand alone functions, so we need # to filter them out. - if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service) + if hasattr(svc_method, "__self__") + and isinstance(svc_method.__self__, Service) } def export(self) -> Dict[str, Callable]: diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index cac3216d61..85cc849b0e 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -59,13 +59,17 @@ class ConfigPersistenceService(Service, SupportsConfigLoading): Collection of methods to deserialize the Environment, Service, and TunableGroups objects. 
""" - BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/") - - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace( + "\\", "/" + ) + + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of config persistence service. @@ -82,17 +86,22 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.resolve_path, - self.load_config, - self.prepare_class_load, - self.build_service, - self.build_environment, - self.load_services, - self.load_environment, - self.load_environment_list, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.resolve_path, + self.load_config, + self.prepare_class_load, + self.build_service, + self.build_environment, + self.load_services, + self.load_environment, + self.load_environment_list, + ], + ), ) self._config_loader_service = self @@ -120,8 +129,9 @@ def config_paths(self) -> List[str]: """ return list(self._config_path) # make a copy to avoid modifications - def resolve_path(self, file_path: str, - extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path( + self, file_path: str, extra_paths: Optional[Iterable[str]] = None + ) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -151,10 +161,11 @@ def resolve_path(self, file_path: str, _LOG.debug("Path not resolved: %s", file_path) return file_path - def load_config(self, - json_file_name: str, - schema_type: Optional[ConfigSchema], - ) -> Dict[str, Any]: + def load_config( + self, + json_file_name: str, + schema_type: Optional[ConfigSchema], + ) -> Dict[str, Any]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. 
@@ -174,16 +185,22 @@ def load_config(self, """ json_file_name = self.resolve_path(json_file_name) _LOG.info("Load config: %s", json_file_name) - with open(json_file_name, mode='r', encoding='utf-8') as fh_json: + with open(json_file_name, mode="r", encoding="utf-8") as fh_json: config = json5.load(fh_json) if schema_type is not None: try: schema_type.validate(config) except (ValidationError, SchemaError) as ex: - _LOG.error("Failed to validate config %s against schema type %s at %s", - json_file_name, schema_type.name, schema_type.value) - raise ValueError(f"Failed to validate config {json_file_name} against " + - f"schema type {schema_type.name} at {schema_type.value}") from ex + _LOG.error( + "Failed to validate config %s against schema type %s at %s", + json_file_name, + schema_type.name, + schema_type.value, + ) + raise ValueError( + f"Failed to validate config {json_file_name} against " + + f"schema type {schema_type.name} at {schema_type.value}" + ) from ex if isinstance(config, dict) and config.get("$schema"): # Remove $schema attributes from the config after we've validated # them to avoid passing them on to other objects @@ -194,11 +211,14 @@ def load_config(self, del config["$schema"] else: _LOG.warning("Config %s is not validated against a schema.", json_file_name) - return config # type: ignore[no-any-return] + return config # type: ignore[no-any-return] - def prepare_class_load(self, config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]: + def prepare_class_load( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + ) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. Mix-in the global parameters and resolve the local file system paths, @@ -232,25 +252,35 @@ def prepare_class_load(self, config: Dict[str, Any], merge_parameters(dest=class_config, source=global_config) - for key in set(class_config).intersection(config.get("resolve_config_property_paths", [])): + for key in set(class_config).intersection( + config.get("resolve_config_property_paths", []) + ): if isinstance(class_config[key], str): class_config[key] = self.resolve_path(class_config[key]) elif isinstance(class_config[key], (list, tuple)): - class_config[key] = [self.resolve_path(path) for path in class_config[key]] + class_config[key] = [ + self.resolve_path(path) for path in class_config[key] + ] else: raise ValueError(f"Parameter {key} must be a string or a list") if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Instantiating: %s with config:\n%s", - class_name, json.dumps(class_config, indent=2)) + _LOG.debug( + "Instantiating: %s with config:\n%s", + class_name, + json.dumps(class_config, indent=2), + ) return (class_name, class_config) - def build_optimizer(self, *, - tunables: TunableGroups, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None) -> Optimizer: + def build_optimizer( + self, + *, + tunables: TunableGroups, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + ) -> Optimizer: """ Instantiation of mlos_bench Optimizer that depend on Service and TunableGroups. 
@@ -279,18 +309,24 @@ def build_optimizer(self, *, if tunables_path is not None: tunables = self._load_tunables(tunables_path, tunables) (class_name, class_config) = self.prepare_class_load(config, global_config) - inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract] - tunables=tunables, - config=class_config, - global_config=global_config, - service=service) + inst = instantiate_from_config( + Optimizer, + class_name, # type: ignore[type-abstract] + tunables=tunables, + config=class_config, + global_config=global_config, + service=service, + ) _LOG.info("Created: Optimizer %s", inst) return inst - def build_storage(self, *, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None) -> "Storage": + def build_storage( + self, + *, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + ) -> "Storage": """ Instantiation of mlos_bench Storage objects. @@ -312,20 +348,27 @@ def build_storage(self, *, from mlos_bench.storage.base_storage import ( Storage, # pylint: disable=import-outside-toplevel ) - inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - service=service) + + inst = instantiate_from_config( + Storage, + class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + service=service, + ) _LOG.info("Created: Storage %s", inst) return inst - def build_scheduler(self, *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: "Storage", - root_env_config: str) -> "Scheduler": + def build_scheduler( + self, + *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: "Storage", + root_env_config: str, + ) -> "Scheduler": """ Instantiation of mlos_bench Scheduler. @@ -353,22 +396,28 @@ def build_scheduler(self, *, from mlos_bench.schedulers.base_scheduler import ( Scheduler, # pylint: disable=import-outside-toplevel ) - inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - environment=environment, - optimizer=optimizer, - storage=storage, - root_env_config=root_env_config) + + inst = instantiate_from_config( + Scheduler, + class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + environment=environment, + optimizer=optimizer, + storage=storage, + root_env_config=root_env_config, + ) _LOG.info("Created: Scheduler %s", inst) return inst - def build_environment(self, # pylint: disable=too-many-arguments - config: Dict[str, Any], - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> Environment: + def build_environment( + self, # pylint: disable=too-many-arguments + config: Dict[str, Any], + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> Environment: """ Factory method for a new environment with a given config. @@ -397,7 +446,9 @@ def build_environment(self, # pylint: disable=too-many-arguments An instance of the `Environment` class initialized with `config`. 
""" env_name = config["name"] - (env_class, env_config) = self.prepare_class_load(config, global_config, parent_args) + (env_class, env_config) = self.prepare_class_load( + config, global_config, parent_args + ) env_services_path = config.get("include_services") if env_services_path is not None: @@ -408,16 +459,24 @@ def build_environment(self, # pylint: disable=too-many-arguments tunables = self._load_tunables(env_tunables_path, tunables) _LOG.debug("Creating env: %s :: %s", env_name, env_class) - env = Environment.new(env_name=env_name, class_name=env_class, - config=env_config, global_config=global_config, - tunables=tunables, service=service) + env = Environment.new( + env_name=env_name, + class_name=env_class, + config=env_config, + global_config=global_config, + tunables=tunables, + service=service, + ) _LOG.info("Created env: %s :: %s", env_name, env) return env - def _build_standalone_service(self, config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def _build_standalone_service( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -442,9 +501,12 @@ def _build_standalone_service(self, config: Dict[str, Any], _LOG.info("Created service: %s", service) return service - def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def _build_composite_service( + self, + config_list: Iterable[Dict[str, Any]], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -470,18 +532,21 @@ def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], service.register(parent.export()) for config in config_list: - service.register(self._build_standalone_service( - config, global_config, service).export()) + service.register( + self._build_standalone_service(config, global_config, service).export() + ) if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Created mix-in service: %s", service) return service - def build_service(self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def build_service( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -503,8 +568,7 @@ def build_service(self, services from the list plus the parent mix-in. 
""" if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Build service from config:\n%s", - json.dumps(config, indent=2)) + _LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2)) assert isinstance(config, dict) config_list: List[Dict[str, Any]] @@ -519,12 +583,14 @@ def build_service(self, return self._build_composite_service(config_list, global_config, parent) - def load_environment(self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> Environment: + def load_environment( + self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> Environment: """ Load and build new environment from the config file. @@ -549,14 +615,18 @@ def load_environment(self, # pylint: disable=too-many-arguments """ config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT) assert isinstance(config, dict) - return self.build_environment(config, tunables, global_config, parent_args, service) + return self.build_environment( + config, tunables, global_config, parent_args, service + ) - def load_environment_list(self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> List[Environment]: + def load_environment_list( + self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> List[Environment]: """ Load and build a list of environments from the config file. @@ -582,12 +652,17 @@ def load_environment_list(self, # pylint: disable=too-many-arguments """ config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT) return [ - self.build_environment(config, tunables, global_config, parent_args, service) + self.build_environment( + config, tunables, global_config, parent_args, service + ) ] - def load_services(self, json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def load_services( + self, + json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Read the configuration files and bundle all service methods from those configs into a single Service object. @@ -606,16 +681,20 @@ def load_services(self, json_file_names: Iterable[str], service : Service A collection of service methods. 
""" - _LOG.info("Load services: %s parent: %s", - json_file_names, parent.__class__.__name__) + _LOG.info( + "Load services: %s parent: %s", json_file_names, parent.__class__.__name__ + ) service = Service({}, global_config, parent) for fname in json_file_names: config = self.load_config(fname, ConfigSchema.SERVICE) - service.register(self.build_service(config, global_config, service).export()) + service.register( + self.build_service(config, global_config, service).export() + ) return service - def _load_tunables(self, json_file_names: Iterable[str], - parent: TunableGroups) -> TunableGroups: + def _load_tunables( + self, json_file_names: Iterable[str], parent: TunableGroups + ) -> TunableGroups: """ Load a collection of tunable parameters from JSON files into the parent TunableGroup. diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py index abb87c8b52..b9d0c267c1 100644 --- a/mlos_bench/mlos_bench/services/local/__init__.py +++ b/mlos_bench/mlos_bench/services/local/__init__.py @@ -9,5 +9,5 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - 'LocalExecService', + "LocalExecService", ] diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 47534be7b1..6b9bca1a0c 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -79,11 +79,13 @@ class LocalExecService(TempDirContextService, SupportsLocalExec): due to reduced dependency management complications vs the target environment. """ - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of a service to run scripts locally. @@ -100,14 +102,19 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.local_exec]) + config, + global_config, + parent, + self.merge_methods(methods, [self.local_exec]), ) self.abort_on_error = self.config.get("abort_on_error", True) - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. 
@@ -133,7 +140,9 @@ def local_exec(self, script_lines: Iterable[str], _LOG.debug("Run in directory: %s", temp_dir) for line in script_lines: - (return_code, stdout, stderr) = self._local_exec_script(line, env, temp_dir) + (return_code, stdout, stderr) = self._local_exec_script( + line, env, temp_dir + ) stdout_list.append(stdout) stderr_list.append(stderr) if return_code != 0 and self.abort_on_error: @@ -175,9 +184,12 @@ def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: subcmd_tokens.insert(0, sys.executable) return subcmd_tokens - def _local_exec_script(self, script_line: str, - env_params: Optional[Mapping[str, "TunableValue"]], - cwd: str) -> Tuple[int, str, str]: + def _local_exec_script( + self, + script_line: str, + env_params: Optional[Mapping[str, "TunableValue"]], + cwd: str, + ) -> Tuple[int, str, str]: """ Execute the script from `script_path` in a local process. @@ -206,7 +218,7 @@ def _local_exec_script(self, script_line: str, if env_params: env = {key: str(val) for (key, val) in env_params.items()} - if sys.platform == 'win32': + if sys.platform == "win32": # A hack to run Python on Windows with env variables set: env_copy = environ.copy() env_copy["PYTHONPATH"] = "" @@ -214,16 +226,25 @@ def _local_exec_script(self, script_line: str, env = env_copy try: - if sys.platform != 'win32': + if sys.platform != "win32": cmd = [" ".join(cmd)] _LOG.info("Run: %s", cmd) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env)) + _LOG.debug( + "Expands to: %s", Template(" ".join(cmd)).safe_substitute(env) + ) _LOG.debug("Current working dir: %s", cwd) - proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True, - text=True, check=False, capture_output=True) + proc = subprocess.run( + cmd, + env=env or None, + cwd=cwd, + shell=True, + text=True, + check=False, + capture_output=True, + ) _LOG.debug("Run: return code = %d", proc.returncode) return (proc.returncode, proc.stdout, proc.stderr) diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py index a0cf3e0e57..cdfe510799 100644 --- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py +++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py @@ -28,11 +28,13 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta): This class is not supposed to be used as a standalone service. """ - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of a service that provides temporary directory context for local exec service. @@ -50,18 +52,24 @@ def __init__(self, New methods to register with the service. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.temp_dir_context]) + config, + global_config, + parent, + self.merge_methods(methods, [self.temp_dir_context]), ) self._temp_dir = self.config.get("temp_dir") if self._temp_dir: # expand globals - self._temp_dir = Template(self._temp_dir).safe_substitute(global_config or {}) + self._temp_dir = Template(self._temp_dir).safe_substitute( + global_config or {} + ) # and resolve the path to absolute path self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir) _LOG.info("%s: temp dir: %s", self, self._temp_dir) - def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]: + def temp_dir_context( + self, path: Optional[str] = None + ) -> Union[TemporaryDirectory, nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py index 61a6c74942..12fe62eeb7 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py @@ -13,9 +13,9 @@ from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService __all__ = [ - 'AzureAuthService', - 'AzureFileShareService', - 'AzureNetworkService', - 'AzureSaaSConfigService', - 'AzureVMService', + "AzureAuthService", + "AzureFileShareService", + "AzureNetworkService", + "AzureSaaSConfigService", + "AzureVMService", ] diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 4121446caf..a5a6bc549c 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -27,13 +27,15 @@ class AzureAuthService(Service, SupportsAuth): Helper methods to get access to Azure services. """ - _REQ_INTERVAL = 300 # = 5 min - - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + _REQ_INTERVAL = 300 # = 5 min + + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure authentication services proxy. @@ -50,18 +52,27 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.get_access_token, - self.get_auth_headers, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.get_access_token, + self.get_auth_headers, + ], + ), ) # This parameter can come from command line as strings, so conversion is needed. - self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL)) + self._req_interval = float( + self.config.get("tokenRequestInterval", self._REQ_INTERVAL) + ) self._access_token = "RENEW *NOW*" - self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp. + self._token_expiration_ts = datetime.now( + UTC + ) # Typically, some future timestamp. 
# Login as ourselves self._cred: Union[azure_id.AzureCliCredential, azure_id.CertificateCredential] @@ -70,12 +81,13 @@ def __init__(self, # Verify info required for SP auth early if "spClientId" in self.config: check_required_params( - self.config, { + self.config, + { "spClientId", "keyVaultName", "certName", "tenant", - } + }, ) def _init_sp(self) -> None: @@ -104,7 +116,9 @@ def _init_sp(self) -> None: cert_bytes = b64decode(secret.value) # Reauthenticate as the service principal. - self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes) + self._cred = azure_id.CertificateCredential( + tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes + ) def get_access_token(self) -> str: """ @@ -121,7 +135,9 @@ def get_access_token(self) -> str: res = self._cred.get_token("https://management.azure.com/.default") self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC) self._access_token = res.token - _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts) + _LOG.info( + "Got new accessToken. Expiration time: %s", self._token_expiration_ts + ) return self._access_token def get_auth_headers(self) -> dict: diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 9f2b504aff..a494867aa0 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -29,9 +29,9 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): Helper methods to manage and deploy Azure resources via REST APIs. """ - _POLL_INTERVAL = 4 # seconds - _POLL_TIMEOUT = 300 # seconds - _REQUEST_TIMEOUT = 5 # seconds + _POLL_INTERVAL = 4 # seconds + _POLL_TIMEOUT = 300 # seconds + _REQUEST_TIMEOUT = 5 # seconds _REQUEST_TOTAL_RETRIES = 10 # Total number retries for each request _REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries})) @@ -39,19 +39,21 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): # https://docs.microsoft.com/en-us/rest/api/resources/deployments _URL_DEPLOY = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Resources" + - "/deployments/{deployment_name}" + - "?api-version=2022-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Resources" + + "/deployments/{deployment_name}" + + "?api-version=2022-05-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of an Azure Services proxy. 
@@ -69,32 +71,50 @@ def __init__(self, """ super().__init__(config, global_config, parent, methods) - check_required_params(self.config, [ - "subscription", - "resourceGroup", - ]) + check_required_params( + self.config, + [ + "subscription", + "resourceGroup", + ], + ) # These parameters can come from command line as strings, so conversion is needed. - self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL)) + self._poll_interval = float( + self.config.get("pollInterval", self._POLL_INTERVAL) + ) self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT)) - self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)) - self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)) + self._request_timeout = float( + self.config.get("requestTimeout", self._REQUEST_TIMEOUT) + ) + self._total_retries = int( + self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES) + ) + self._backoff_factor = float( + self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR) + ) self._deploy_template = {} self._deploy_params = {} if self.config.get("deploymentTemplatePath") is not None: # TODO: Provide external schema validation? template = self.config_loader_service.load_config( - self.config['deploymentTemplatePath'], schema_type=None) + self.config["deploymentTemplatePath"], schema_type=None + ) assert template is not None and isinstance(template, dict) self._deploy_template = template # Allow for recursive variable expansion as we do with global params and const_args. - deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config) - self._deploy_params = merge_parameters(dest=deploy_params, source=global_config) + deploy_params = DictTemplater( + self.config["deploymentTemplateParameters"] + ).expand_vars(extra_source_dict=global_config) + self._deploy_params = merge_parameters( + dest=deploy_params, source=global_config + ) else: - _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.") + _LOG.info( + "No deploymentTemplatePath provided. Deployment services will be unavailable." + ) @property def deploy_params(self) -> dict: @@ -129,7 +149,10 @@ def _get_session(self, params: dict) -> requests.Session: session = requests.Session() session.mount( "https://", - HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor))) + HTTPAdapter( + max_retries=Retry(total=total_retries, backoff_factor=backoff_factor) + ), + ) session.headers.update(self._get_headers()) return session @@ -137,8 +160,9 @@ def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ - "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" 
return self._parent.get_auth_headers() @staticmethod @@ -153,11 +177,15 @@ def _extract_arm_parameters(json_data: dict) -> dict: """ return { key: val.get("value") - for (key, val) in json_data.get("properties", {}).get("parameters", {}).items() + for (key, val) in json_data.get("properties", {}) + .get("parameters", {}) + .items() if val.get("value") is not None } - def _azure_rest_api_post_helper(self, params: dict, url: str) -> Tuple[Status, dict]: + def _azure_rest_api_post_helper( + self, params: dict, url: str + ) -> Tuple[Status, dict]: """ General pattern for performing an action on an Azure resource via its REST API. @@ -179,7 +207,9 @@ def _azure_rest_api_post_helper(self, params: dict, url: str) -> Tuple[Status, d """ _LOG.debug("Request: POST %s", url) - response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout) + response = requests.post( + url, headers=self._get_headers(), timeout=self._request_timeout + ) _LOG.debug("Response: %s", response) # Logical flow for async operations based on: @@ -227,16 +257,20 @@ def _check_operation_status(self, params: dict) -> Tuple[Status, dict]: try: response = session.get(url, timeout=self._request_timeout) except requests.exceptions.ReadTimeout: - _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url) + _LOG.warning( + "Request timed out after %.2f s: %s", self._request_timeout, url + ) return Status.RUNNING, {} except requests.exceptions.RequestException as ex: _LOG.exception("Error in request checking operation status", exc_info=ex) return (Status.FAILED, {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) if response.status_code == 200: output = response.json() @@ -269,12 +303,19 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ params = self._set_default_params(params) - _LOG.info("Wait for %s to %s", params.get("deploymentName"), - "provision" if is_setup else "deprovision") + _LOG.info( + "Wait for %s to %s", + params.get("deploymentName"), + "provision" if is_setup else "deprovision", + ) return self._wait_while(self._check_deployment, Status.PENDING, params) - def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], - loop_status: Status, params: dict) -> Tuple[Status, dict]: + def _wait_while( + self, + func: Callable[[dict], Tuple[Status, dict]], + loop_status: Status, + params: dict, + ) -> Tuple[Status, dict]: """ Invoke `func` periodically while the status is equal to `loop_status`. Return TIMED_OUT when timing out. 
@@ -296,12 +337,18 @@ def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], """ params = self._set_default_params(params) config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"]) + dest=self.config.copy(), source=params, required_keys=["deploymentName"] + ) poll_period = params.get("pollInterval", self._poll_interval) - _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s", - config["deploymentName"], loop_status, poll_period, self._poll_timeout) + _LOG.debug( + "Wait for %s status %s :: poll %.2f timeout %d s", + config["deploymentName"], + loop_status, + poll_period, + self._poll_timeout, + ) ts_timeout = time.time() + self._poll_timeout poll_delay = poll_period @@ -325,7 +372,9 @@ def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], _LOG.warning("Request timed out: %s", params) return (Status.TIMED_OUT, {}) - def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements + def _check_deployment( + self, params: dict + ) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements """ Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. @@ -351,7 +400,7 @@ def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: di "subscription", "resourceGroup", "deploymentName", - ] + ], ) _LOG.info("Check deployment: %s", config["deploymentName"]) @@ -366,7 +415,9 @@ def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: di try: response = session.get(url, timeout=self._request_timeout) except requests.exceptions.ReadTimeout: - _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url) + _LOG.warning( + "Request timed out after %.2f s: %s", self._request_timeout, url + ) return Status.RUNNING, {} except requests.exceptions.RequestException as ex: _LOG.exception("Error in request checking deployment", exc_info=ex) @@ -412,13 +463,18 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: if not self._deploy_template: raise ValueError(f"Missing deployment template: {self}") params = self._set_default_params(params) - config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"]) + config = merge_parameters( + dest=self.config.copy(), source=params, required_keys=["deploymentName"] + ) _LOG.info("Deploy: %s :: %s", config["deploymentName"], params) params = merge_parameters(dest=self._deploy_params.copy(), source=params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Deploy: %s merged params ::\n%s", - config["deploymentName"], json.dumps(params, indent=2)) + _LOG.debug( + "Deploy: %s merged params ::\n%s", + config["deploymentName"], + json.dumps(params, indent=2), + ) url = self._URL_DEPLOY.format( subscription=config["subscription"], @@ -431,22 +487,29 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: "mode": "Incremental", "template": self._deploy_template, "parameters": { - key: {"value": val} for (key, val) in params.items() + key: {"value": val} + for (key, val) in params.items() if key in self._deploy_template.get("parameters", {}) - } + }, } } if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2)) - response = requests.put(url, json=json_req, - headers=self._get_headers(), timeout=self._request_timeout) + response = requests.put( + url, + json=json_req, + headers=self._get_headers(), + timeout=self._request_timeout, + ) if 
_LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) else: _LOG.info("Response: %s", response) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 6ccd4ba09d..bec45a967d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -27,11 +27,13 @@ class AzureFileShareService(FileShareService): _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}" - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new file share Service for Azure environments with a given config. @@ -49,16 +51,19 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) check_required_params( - self.config, { + self.config, + { "storageAccountName", "storageFileShareName", "storageAccountKey", - } + }, ) self._share_client = ShareClient.from_share_url( @@ -69,7 +74,9 @@ def __init__(self, credential=self.config["storageAccountKey"], ) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: super().download(params, remote_path, local_path, recursive) dir_client = self._share_client.get_directory_client(remote_path) if dir_client.exists(): @@ -94,11 +101,15 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b # Translate into non-Azure exception: raise FileNotFoundError(f"Cannot download: {remote_path}") from ex - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: super().upload(params, local_path, remote_path, recursive) self._upload(local_path, remote_path, recursive, set()) - def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None: + def _upload( + self, local_path: str, remote_path: str, recursive: bool, seen: Set[str] + ) -> None: """ Upload contents from a local path to an Azure file share. This method is called from `.upload()` above. 
We need it to avoid exposing diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index d65ee02cfd..95e16892cc 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -32,20 +32,22 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 _URL_DEPROVISION = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Network" + - "/virtualNetwork/{vnet_name}" + - "/delete" + - "?api-version=2023-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Network" + + "/virtualNetwork/{vnet_name}" + + "/delete" + + "?api-version=2023-05-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure Network services proxy. @@ -62,28 +64,40 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - # SupportsNetworkProvisioning - self.provision_network, - self.deprovision_network, - self.wait_network_deployment, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + # SupportsNetworkProvisioning + self.provision_network, + self.deprovision_network, + self.wait_network_deployment, + ], + ), ) if not self._deploy_template: - raise ValueError("AzureNetworkService requires a deployment template:\n" - + f"config={config}\nglobal_config={global_config}") + raise ValueError( + "AzureNetworkService requires a deployment template:\n" + + f"config={config}\nglobal_config={global_config}" + ) - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vnetName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vnetName']}-deployment" - _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) + _LOG.info( + "deploymentName missing from params. Defaulting to '%s'.", + params["deploymentName"], + ) return params - def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: + def wait_network_deployment( + self, params: dict, *, is_setup: bool + ) -> Tuple[Status, dict]: """ Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. 
@@ -124,7 +138,9 @@ def provision_network(self, params: dict) -> Tuple[Status, dict]: """ return self._provision_resource(params) - def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple[Status, dict]: + def deprovision_network( + self, params: dict, ignore_errors: bool = True + ) -> Tuple[Status, dict]: """ Deprovisions the virtual network on Azure by deleting it. @@ -151,15 +167,18 @@ def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple "resourceGroup", "deploymentName", "vnetName", - ] + ], ) _LOG.info("Deprovision Network: %s", config["vnetName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) - (status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vnet_name=config["vnetName"], - )) + (status, results) = self._azure_rest_api_post_helper( + config, + self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vnet_name=config["vnetName"], + ), + ) if ignore_errors and status == Status.FAILED: _LOG.warning("Ignoring error: %s", results) status = Status.SUCCEEDED diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index a92d279a6d..03928a4b18 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -32,20 +32,22 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig): # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations _URL_CONFIGURE = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/{provider}" + - "/{server_type}/{vm_name}" + - "/{update}" + - "?api-version={api_version}" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/{provider}" + + "/{server_type}/{vm_name}" + + "/{update}" + + "?api-version={api_version}" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure services proxy. @@ -62,18 +64,20 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.configure, - self.is_config_pending - ]) + config, + global_config, + parent, + self.merge_methods(methods, [self.configure, self.is_config_pending]), ) - check_required_params(self.config, { - "subscription", - "resourceGroup", - "provider", - }) + check_required_params( + self.config, + { + "subscription", + "resourceGroup", + "provider", + }, + ) # Provide sane defaults for known DB providers. 
provider = self.config.get("provider") @@ -100,7 +104,11 @@ def __init__(self, provider=self.config["provider"], vm_name="{vm_name}", server_type="flexibleServers" if is_flex else "servers", - update="updateConfigurations" if self._is_batch else "configurations/{param_name}", + update=( + "updateConfigurations" + if self._is_batch + else "configurations/{param_name}" + ), api_version=api_version, ) @@ -115,10 +123,13 @@ def __init__(self, ) # These parameters can come from command line as strings, so conversion is needed. - self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) + self._request_timeout = float( + self.config.get("requestTimeout", self._REQUEST_TIMEOUT) + ) - def configure(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def configure( + self, config: Dict[str, Any], params: Dict[str, Any] + ) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service. @@ -157,32 +168,43 @@ def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]: Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED} """ config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + dest=self.config.copy(), source=config, required_keys=["vmName"] + ) url = self._url_config_get.format(vm_name=config["vmName"]) _LOG.debug("Request: GET %s", url) response = requests.put( - url, headers=self._get_headers(), timeout=self._request_timeout) + url, headers=self._get_headers(), timeout=self._request_timeout + ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) if response.status_code != 200: return (Status.FAILED, {}) # Currently, Azure Flex servers require a VM reboot. - return (Status.SUCCEEDED, {"isConfigPendingReboot": any( - {'False': False, 'True': True}[val['properties']['isConfigPendingRestart']] - for val in response.json()['value'] - )}) + return ( + Status.SUCCEEDED, + { + "isConfigPendingReboot": any( + {"False": False, "True": True}[ + val["properties"]["isConfigPendingRestart"] + ] + for val in response.json()["value"] + ) + }, + ) def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ - "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() - def _config_one(self, config: Dict[str, Any], - param_name: str, param_value: Any) -> Tuple[Status, dict]: + def _config_one( + self, config: Dict[str, Any], param_name: str, param_value: Any + ) -> Tuple[Status, dict]: """ Update a single parameter of the Azure DB service. 
@@ -202,12 +224,18 @@ def _config_one(self, config: Dict[str, Any], Status is one of {PENDING, SUCCEEDED, FAILED} """ config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) - url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name) + dest=self.config.copy(), source=config, required_keys=["vmName"] + ) + url = self._url_config_set.format( + vm_name=config["vmName"], param_name=param_name + ) _LOG.debug("Request: PUT %s", url) - response = requests.put(url, headers=self._get_headers(), - json={"properties": {"value": str(param_value)}}, - timeout=self._request_timeout) + response = requests.put( + url, + headers=self._get_headers(), + json={"properties": {"value": str(param_value)}}, + timeout=self._request_timeout, + ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) @@ -215,8 +243,9 @@ def _config_one(self, config: Dict[str, Any], return (Status.SUCCEEDED, {}) return (Status.FAILED, {}) - def _config_many(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_many( + self, config: Dict[str, Any], params: Dict[str, Any] + ) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service one-by-one. (If batch API is not available for it). @@ -234,14 +263,15 @@ def _config_many(self, config: Dict[str, Any], A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ - for (param_name, param_value) in params.items(): + for param_name, param_value in params.items(): (status, result) = self._config_one(config, param_name, param_value) if not status.is_succeeded(): return (status, result) return (Status.SUCCEEDED, {}) - def _config_batch(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_batch( + self, config: Dict[str, Any], params: Dict[str, Any] + ) -> Tuple[Status, dict]: """ Batch update the parameters of an Azure DB service. 
@@ -259,7 +289,8 @@ def _config_batch(self, config: Dict[str, Any], Status is one of {PENDING, SUCCEEDED, FAILED} """ config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + dest=self.config.copy(), source=config, required_keys=["vmName"] + ) url = self._url_config_set.format(vm_name=config["vmName"]) json_req = { "value": [ @@ -269,8 +300,12 @@ def _config_batch(self, config: Dict[str, Any], # "resetAllToDefault": "True" } _LOG.debug("Request: POST %s", url) - response = requests.post(url, headers=self._get_headers(), - json=json_req, timeout=self._request_timeout) + response = requests.post( + url, + headers=self._get_headers(), + json=json_req, + timeout=self._request_timeout, + ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index ddce3cc935..5f79219c08 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -26,7 +26,13 @@ _LOG = logging.getLogger(__name__) -class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec): +class AzureVMService( + AzureDeploymentService, + SupportsHostProvisioning, + SupportsHostOps, + SupportsOSOps, + SupportsRemoteExec, +): """ Helper methods to manage VMs on Azure. """ @@ -38,35 +44,35 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start _URL_START = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/start" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/start" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off _URL_STOP = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/powerOff" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/powerOff" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate _URL_DEALLOCATE = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/deallocate" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/deallocate" + + "?api-version=2022-03-01" ) # TODO: This is probably the more correct URL to use for the deprovision operation. 
@@ -88,31 +94,33 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart _URL_REBOOT = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/restart" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/restart" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command _URL_REXEC_RUN = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/runCommand" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/runCommand" + + "?api-version=2022-03-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure VM services proxy. @@ -129,26 +137,31 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - # SupportsHostProvisioning - self.provision_host, - self.deprovision_host, - self.deallocate_host, - self.wait_host_deployment, - # SupportsHostOps - self.start_host, - self.stop_host, - self.restart_host, - self.wait_host_operation, - # SupportsOSOps - self.shutdown, - self.reboot, - self.wait_os_operation, - # SupportsRemoteExec - self.remote_exec, - self.get_remote_exec_results, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + # SupportsHostProvisioning + self.provision_host, + self.deprovision_host, + self.deallocate_host, + self.wait_host_deployment, + # SupportsHostOps + self.start_host, + self.stop_host, + self.restart_host, + self.wait_host_operation, + # SupportsOSOps + self.shutdown, + self.reboot, + self.wait_os_operation, + # SupportsRemoteExec + self.remote_exec, + self.get_remote_exec_results, + ], + ), ) # As a convenience, allow reading customData out of a file, rather than @@ -157,22 +170,29 @@ def __init__(self, # can be done using the `base64()` string function inside the ARM template. 
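The `_URL_*` constants above are re-wrapped with the `+` at the start of each continuation line, matching black's output and PEP 8's preference for breaking before binary operators. Purely as an aside, and not something this patch does: adjacent string literals concatenate implicitly in Python, which would drop the operators entirely. An illustrative sketch with a hypothetical constant name:

# Illustrative only: implicit concatenation of adjacent string literals
# builds the same kind of URL template without any '+' operators.
_URL_EXAMPLE = (
    "https://management.azure.com"
    "/subscriptions/{subscription}"
    "/resourceGroups/{resource_group}"
    "/providers/Microsoft.Compute"
    "/virtualMachines/{vm_name}"
    "?api-version=2022-03-01"
)
print(_URL_EXAMPLE.format(subscription="sub-id", resource_group="rg", vm_name="vm1"))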
self._custom_data_file = self.config.get("customDataFile", None) if self._custom_data_file: - if self._deploy_params.get('customData', None): + if self._deploy_params.get("customData", None): raise ValueError("Both customDataFile and customData are specified.") - self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file) - with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh: + self._custom_data_file = self.config_loader_service.resolve_path( + self._custom_data_file + ) + with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh: self._deploy_params["customData"] = custom_data_fh.read() - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vmName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vmName']}-deployment" - _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) + _LOG.info( + "deploymentName missing from params. Defaulting to '%s'.", + params["deploymentName"], + ) return params - def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: + def wait_host_deployment( + self, params: dict, *, is_setup: bool + ) -> Tuple[Status, dict]: """ Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. @@ -264,16 +284,19 @@ def deprovision_host(self, params: dict) -> Tuple[Status, dict]: "resourceGroup", "deploymentName", "vmName", - ] + ], ) _LOG.info("Deprovision VM: %s", config["vmName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) # TODO: Properly deprovision *all* resources specified in the ARM template. 
- return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def deallocate_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -301,14 +324,17 @@ def deallocate_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Deallocate VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_DEALLOCATE.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def start_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -333,14 +359,17 @@ def start_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Start VM: %s :: %s", config["vmName"], params) - return self._azure_rest_api_post_helper(config, self._URL_START.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_START.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: """ @@ -367,14 +396,17 @@ def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Stop VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_STOP.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_STOP.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.stop_host(params, force) @@ -404,20 +436,24 @@ def restart_host(self, params: dict, force: bool = False) -> Tuple[Status, dict] "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Reboot VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_REBOOT.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.restart_host(params, force) - def remote_exec(self, script: Iterable[str], config: dict, - env_params: dict) -> Tuple[Status, dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple[Status, dict]: """ Run a command on Azure VM. 
@@ -447,16 +483,20 @@ def remote_exec(self, script: Iterable[str], config: dict, "subscription", "resourceGroup", "vmName", - ] + ], ) if _LOG.isEnabledFor(logging.INFO): - _LOG.info("Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script)) + _LOG.info( + "Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script) + ) json_req = { "commandId": "RunShellScript", "script": list(script), - "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()] + "parameters": [ + {"name": key, "value": val} for (key, val) in env_params.items() + ], } url = self._URL_REXEC_RUN.format( @@ -469,12 +509,18 @@ def remote_exec(self, script: Iterable[str], config: dict, _LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2)) response = requests.post( - url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout) + url, + json=json_req, + headers=self._get_headers(), + timeout=self._request_timeout, + ) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) else: _LOG.info("Response: %s", response) @@ -482,10 +528,13 @@ def remote_exec(self, script: Iterable[str], config: dict, # TODO: extract the results from JSON response return (Status.SUCCEEDED, config) elif response.status_code == 202: - return (Status.PENDING, { - **config, - "asyncResultsUrl": response.headers.get("Azure-AsyncOperation") - }) + return ( + Status.PENDING, + { + **config, + "asyncResultsUrl": response.headers.get("Azure-AsyncOperation"), + }, + ) else: _LOG.error("Response: %s :: %s", response, response.text) # _LOG.error("Bad Request:\n%s", response.request.body) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index f623cdfcc8..19290886e4 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -31,9 +31,14 @@ class CopyMode(Enum): class SshFileShareService(FileShareService, SshService): """A collection of functions for interacting with SSH servers as file shares.""" - async def _start_file_copy(self, params: dict, mode: CopyMode, - local_path: str, remote_path: str, - recursive: bool = True) -> None: + async def _start_file_copy( + self, + params: dict, + mode: CopyMode, + local_path: str, + remote_path: str, + recursive: bool = True, + ) -> None: # pylint: disable=too-many-arguments """ Starts a file copy operation @@ -71,44 +76,74 @@ async def _start_file_copy(self, params: dict, mode: CopyMode, dstpath = (connection, remote_path) else: raise ValueError(f"Unknown copy mode: {mode}") - return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True) + return await scp( + srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True + ) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ] + ], ) super().download(params, remote_path, local_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive)) + 
self._start_file_copy( + params, CopyMode.DOWNLOAD, local_path, remote_path, recursive + ) + ) try: file_copy_future.result() except (OSError, SFTPError) as ex: - _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex) + _LOG.error( + "Failed to download %s to %s from %s: %s", + remote_path, + local_path, + params, + ex, + ) if isinstance(ex, SFTPNoSuchFile) or ( - isinstance(ex, SFTPFailure) and ex.code == 4 - and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory")) + isinstance(ex, SFTPFailure) + and ex.code == 4 + and any( + msg.lower() in ex.reason.lower() + for msg in ("File not found", "No such file or directory") + ) ): _LOG.warning("File %s does not exist on %s", remote_path, params) - raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex + raise FileNotFoundError( + f"File {remote_path} does not exist on {params}" + ) from ex raise ex - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ] + ], ) super().upload(params, local_path, remote_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)) + self._start_file_copy( + params, CopyMode.UPLOAD, local_path, remote_path, recursive + ) + ) try: file_copy_future.result() except (OSError, SFTPError) as ex: - _LOG.error("Failed to upload %s to %s on %s: %s", local_path, remote_path, params, ex) + _LOG.error( + "Failed to upload %s to %s on %s: %s", + local_path, + remote_path, + params, + ex, + ) raise ex diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index a650ff0707..dad7cb971c 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -29,11 +29,13 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of an SSH Service. 
@@ -52,17 +54,25 @@ def __init__(self, # Same methods are also provided by the AzureVMService class # pylint: disable=duplicate-code super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.shutdown, - self.reboot, - self.wait_os_operation, - self.remote_exec, - self.get_remote_exec_results, - ])) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.shutdown, + self.reboot, + self.wait_os_operation, + self.remote_exec, + self.get_remote_exec_results, + ], + ), + ) self._shell = self.config.get("ssh_shell", "/bin/bash") - async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess: + async def _run_cmd( + self, params: dict, script: Iterable[str], env_params: dict + ) -> SSHCompletedProcess: """ Runs a command asynchronously on a host via SSH. @@ -84,17 +94,22 @@ async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) connection, _ = await self._get_client_connection(params) # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values. # Handle transferring environment variables by making a script to set them. - env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()] - script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()] + env_script_lines = [ + f"export {name}='{value}'" for (name, value) in env_params.items() + ] + script_lines = env_script_lines + [ + line_split for line in script for line_split in line.splitlines() + ] # Note: connection.run() uses "exec" with a shell by default. - script_str = '\n'.join(script_lines) + script_str = "\n".join(script_lines) _LOG.debug("Running script on %s:\n%s", connection, script_str) - return await connection.run(script_str, - check=False, - timeout=self._request_timeout, - env=env_params) + return await connection.run( + script_str, check=False, timeout=self._request_timeout, env=env_params + ) - def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple["Status", dict]: """ Start running a command on remote host OS. 
@@ -121,9 +136,11 @@ def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> source=config, required_keys=[ "ssh_hostname", - ] + ], + ) + config["asyncRemoteExecResultsFuture"] = self._run_coroutine( + self._run_cmd(config, script, env_params) ) - config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params)) return (Status.PENDING, config) def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: @@ -151,10 +168,22 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: try: result = future.result(timeout=self._request_timeout) assert isinstance(result, SSHCompletedProcess) - stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout - stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr + stdout = ( + result.stdout.decode() + if isinstance(result.stdout, bytes) + else result.stdout + ) + stderr = ( + result.stderr.decode() + if isinstance(result.stderr, bytes) + else result.stderr + ) return ( - Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED, + ( + Status.SUCCEEDED + if result.exit_status == 0 and result.returncode == 0 + else Status.FAILED + ), { "stdout": stdout, "stderr": stderr, @@ -165,7 +194,9 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: _LOG.error("Failed to get remote exec results: %s", ex) return (Status.FAILED, {"result": result}) - def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]: + def _exec_os_op( + self, cmd_opts_list: List[str], params: dict + ) -> Tuple[Status, dict]: """_summary_ Parameters @@ -186,9 +217,9 @@ def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, d source=params, required_keys=[ "ssh_hostname", - ] + ], ) - cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list]) + cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list]) script = rf""" if [[ $EUID -ne 0 ]]; then sudo=$(command -v sudo) @@ -223,10 +254,10 @@ def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - 'shutdown -h now', - 'poweroff', - 'halt -p', - 'systemctl poweroff', + "shutdown -h now", + "poweroff", + "halt -p", + "systemctl poweroff", ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) @@ -248,11 +279,11 @@ def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - 'shutdown -r now', - 'reboot', - 'halt --reboot', - 'systemctl reboot', - 'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1', + "shutdown -r now", + "reboot", + "halt --reboot", + "systemctl reboot", + "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1", ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 8bc90eb3da..ae18ad4834 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -50,8 +50,8 @@ class SshClient(asyncssh.SSHClient): reconnect for each command. 
""" - _CONNECTION_PENDING = 'INIT' - _CONNECTION_LOST = 'LOST' + _CONNECTION_PENDING = "INIT" + _CONNECTION_LOST = "LOST" def __init__(self, *args: tuple, **kwargs: dict): self._connection_id: str = SshClient._CONNECTION_PENDING @@ -65,7 +65,7 @@ def __repr__(self) -> str: @staticmethod def id_from_connection(connection: SSHClientConnection) -> str: """Gets a unique id repr for the connection.""" - return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access + return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access @staticmethod def id_from_params(connect_params: dict) -> str: @@ -79,8 +79,12 @@ def connection_made(self, conn: SSHClientConnection) -> None: Changes the connection_id from _CONNECTION_PENDING to a unique id repr. """ self._conn_event.clear() - _LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \ - # pylint: disable=protected-access + _LOG.debug( + "%s: Connection made by %s: %s", + current_thread().name, + conn._options.env, + conn, + ) # pylint: disable=protected-access self._connection_id = SshClient.id_from_connection(conn) self._connection = conn self._conn_event.set() @@ -90,9 +94,19 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self._conn_event.clear() _LOG.debug("%s: %s", current_thread().name, "connection_lost") if exc is None: - _LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc) + _LOG.debug( + "%s: gracefully disconnected ssh from %s: %s", + current_thread().name, + self._connection_id, + exc, + ) else: - _LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc) + _LOG.debug( + "%s: ssh connection lost on %s: %s", + current_thread().name, + self._connection_id, + exc, + ) self._connection_id = SshClient._CONNECTION_LOST self._connection = None self._conn_event.set() @@ -104,7 +118,11 @@ async def connection(self) -> Optional[SSHClientConnection]: """ _LOG.debug("%s: Waiting for connection to be available.", current_thread().name) await self._conn_event.wait() - _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id) + _LOG.debug( + "%s: Connection available for %s", + current_thread().name, + self._connection_id, + ) return self._connection @@ -145,7 +163,9 @@ def exit(self) -> None: warn(RuntimeWarning("SshClientCache lock was still held on exit.")) self._cache_lock.release() - async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]: + async def get_client_connection( + self, connect_params: dict + ) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) client connection. @@ -159,33 +179,57 @@ async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientCo Tuple[SSHClientConnection, SshClient] A tuple of (SSHClientConnection, SshClient). 
""" - _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params) + _LOG.debug( + "%s: get_client_connection: %s", current_thread().name, connect_params + ) async with self._cache_lock: connection_id = SshClient.id_from_params(connect_params) client: Union[None, SshClient, asyncssh.SSHClient] _, client = self._cache.get(connection_id, (None, None)) if client: - _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id) + _LOG.debug( + "%s: Checking cached client %s", + current_thread().name, + connection_id, + ) connection = await client.connection() if not connection: - _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id) + _LOG.debug( + "%s: Removing stale client connection %s from cache.", + current_thread().name, + connection_id, + ) self._cache.pop(connection_id) # Try to reconnect next. else: - _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id) + _LOG.debug( + "%s: Using cached client %s", + current_thread().name, + connection_id, + ) if connection_id not in self._cache: - _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id) - connection, client = await asyncssh.create_connection(SshClient, **connect_params) + _LOG.debug( + "%s: Establishing client connection to %s", + current_thread().name, + connection_id, + ) + connection, client = await asyncssh.create_connection( + SshClient, **connect_params + ) assert isinstance(client, SshClient) self._cache[connection_id] = (connection, client) - _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id) + _LOG.debug( + "%s: Created connection to %s.", + current_thread().name, + connection_id, + ) return self._cache[connection_id] def cleanup(self) -> None: """ Closes all cached connections. """ - for (connection, _) in self._cache.values(): + for connection, _ in self._cache.values(): connection.close() self._cache = {} @@ -225,24 +269,28 @@ class SshService(Service, metaclass=ABCMeta): _REQUEST_TIMEOUT: Optional[float] = None # seconds - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__(config, global_config, parent, methods) # Make sure that the value we allow overriding on a per-connection # basis are present in the config so merge_parameters can do its thing. - self.config.setdefault('ssh_port', None) - assert isinstance(self.config['ssh_port'], (int, type(None))) - self.config.setdefault('ssh_username', None) - assert isinstance(self.config['ssh_username'], (str, type(None))) - self.config.setdefault('ssh_priv_key_path', None) - assert isinstance(self.config['ssh_priv_key_path'], (str, type(None))) + self.config.setdefault("ssh_port", None) + assert isinstance(self.config["ssh_port"], (int, type(None))) + self.config.setdefault("ssh_username", None) + assert isinstance(self.config["ssh_username"], (str, type(None))) + self.config.setdefault("ssh_priv_key_path", None) + assert isinstance(self.config["ssh_priv_key_path"], (str, type(None))) # None can be used to disable the request timeout. 
- self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT) + self._request_timeout = self.config.get( + "ssh_request_timeout", self._REQUEST_TIMEOUT + ) self._request_timeout = nullable(float, self._request_timeout) # Prep an initial connect_params. @@ -250,24 +298,32 @@ def __init__(self, # In general scripted commands shouldn't need a pty and having one # available can confuse some commands, though we may need to make # this configurable in the future. - 'request_pty': False, + "request_pty": False, # By default disable known_hosts checking (since most VMs expected to be dynamically created). - 'known_hosts': None, + "known_hosts": None, } - if 'ssh_known_hosts_file' in self.config: - self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None) - if isinstance(self._connect_params['known_hosts'], str): - known_hosts_file = os.path.expanduser(self._connect_params['known_hosts']) + if "ssh_known_hosts_file" in self.config: + self._connect_params["known_hosts"] = self.config.get( + "ssh_known_hosts_file", None + ) + if isinstance(self._connect_params["known_hosts"], str): + known_hosts_file = os.path.expanduser( + self._connect_params["known_hosts"] + ) if not os.path.exists(known_hosts_file): - raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist") - self._connect_params['known_hosts'] = known_hosts_file - if self._connect_params['known_hosts'] is None: + raise ValueError( + f"ssh_known_hosts_file {known_hosts_file} does not exist" + ) + self._connect_params["known_hosts"] = known_hosts_file + if self._connect_params["known_hosts"] is None: _LOG.info("%s known_hosts checking is disabled per config.", self) - if 'ssh_keepalive_interval' in self.config: - keepalive_internal = self.config.get('ssh_keepalive_interval') - self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal) + if "ssh_keepalive_interval" in self.config: + keepalive_internal = self.config.get("ssh_keepalive_interval") + self._connect_params["keepalive_interval"] = nullable( + int, keepalive_internal + ) def _enter_context(self) -> "SshService": # Start the background thread if it's not already running. @@ -277,9 +333,12 @@ def _enter_context(self) -> "SshService": super()._enter_context() return self - def _exit_context(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def _exit_context( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: # Stop the background thread if it's not needed anymore and potentially # cleanup the cache as well. assert self._in_context @@ -295,7 +354,9 @@ def clear_client_cache(cls) -> None: """ cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup() - def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: + def _run_coroutine( + self, coro: Coroutine[Any, Any, CoroReturnType] + ) -> FutureReturnType: """ Runs the given coroutine in the background event loop thread. @@ -334,28 +395,32 @@ def _get_connect_params(self, params: dict) -> dict: # Start with the base config params. 
connect_params = self._connect_params.copy() - connect_params['host'] = params['ssh_hostname'] # required + connect_params["host"] = params["ssh_hostname"] # required - if params.get('ssh_port'): - connect_params['port'] = int(params.pop('ssh_port')) - elif self.config['ssh_port']: - connect_params['port'] = int(self.config['ssh_port']) + if params.get("ssh_port"): + connect_params["port"] = int(params.pop("ssh_port")) + elif self.config["ssh_port"]: + connect_params["port"] = int(self.config["ssh_port"]) - if 'ssh_username' in params: - connect_params['username'] = str(params.pop('ssh_username')) - elif self.config['ssh_username']: - connect_params['username'] = str(self.config['ssh_username']) + if "ssh_username" in params: + connect_params["username"] = str(params.pop("ssh_username")) + elif self.config["ssh_username"]: + connect_params["username"] = str(self.config["ssh_username"]) - priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path']) + priv_key_file: Optional[str] = params.get( + "ssh_priv_key_path", self.config["ssh_priv_key_path"] + ) if priv_key_file: priv_key_file = os.path.expanduser(priv_key_file) if not os.path.exists(priv_key_file): raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist") - connect_params['client_keys'] = [priv_key_file] + connect_params["client_keys"] = [priv_key_file] return connect_params - async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]: + async def _get_client_connection( + self, params: dict + ) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) SshClient (connection) for the given connection params. @@ -370,4 +435,8 @@ async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnectio The connection and client objects. """ assert self._in_context - return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params)) + return ( + await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection( + self._get_connect_params(params) + ) + ) diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index 725d0c3306..02bb06e755 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -18,12 +18,12 @@ from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec __all__ = [ - 'SupportsAuth', - 'SupportsConfigLoading', - 'SupportsFileShareOps', - 'SupportsHostProvisioning', - 'SupportsLocalExec', - 'SupportsNetworkProvisioning', - 'SupportsRemoteConfig', - 'SupportsRemoteExec', + "SupportsAuth", + "SupportsConfigLoading", + "SupportsFileShareOps", + "SupportsHostProvisioning", + "SupportsLocalExec", + "SupportsNetworkProvisioning", + "SupportsRemoteConfig", + "SupportsRemoteExec", ] diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index 05853da0a9..04d1c44ca9 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -34,8 +34,9 @@ class SupportsConfigLoading(Protocol): Protocol interface for helper functions to lookup and load configs. 
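Many of the hunks above and below only change quote characters (e.g. 'ssh_port' becoming "ssh_port", and the entries in `__all__`): black normalizes string literals to double quotes unless switching would introduce extra backslash escapes. A tiny self-contained illustration (the variable names are made up for this note):

# Black rewrites single-quoted literals to double quotes...
state = "INIT"
# ...but leaves a literal alone when switching quotes would force escaping.
message = 'She said "hi".'
print(state, message)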
""" - def resolve_path(self, file_path: str, - extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path( + self, file_path: str, extra_paths: Optional[Iterable[str]] = None + ) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -53,7 +54,9 @@ def resolve_path(self, file_path: str, An actual path to the config or script. """ - def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]: + def load_config( + self, json_file_name: str, schema_type: Optional[ConfigSchema] + ) -> Union[dict, List[dict]]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. @@ -72,12 +75,14 @@ def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) Free-format dictionary that contains the configuration. """ - def build_environment(self, # pylint: disable=too-many-arguments - config: dict, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None) -> "Environment": + def build_environment( + self, # pylint: disable=too-many-arguments + config: dict, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -107,12 +112,13 @@ def build_environment(self, # pylint: disable=too-many-arguments """ def load_environment_list( # pylint: disable=too-many-arguments - self, - json_file_name: str, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None) -> List["Environment"]: + self, + json_file_name: str, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None, + ) -> List["Environment"]: """ Load and build a list of environments from the config file. @@ -137,9 +143,12 @@ def load_environment_list( # pylint: disable=too-many-arguments A list of new benchmarking environments. """ - def load_services(self, json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None) -> "Service": + def load_services( + self, + json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + ) -> "Service": """ Read the configuration files and bundle all service methods from those configs into a single Service object. diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py index 87ec9e49da..8252dc17ed 100644 --- a/mlos_bench/mlos_bench/services/types/fileshare_type.py +++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py @@ -15,7 +15,9 @@ class SupportsFileShareOps(Protocol): Protocol interface for file share operations. """ - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: """ Downloads contents from a remote share path to a local path. 
@@ -33,7 +35,9 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b if True (the default), download the entire directory tree. """ - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: """ Uploads contents from a local path to remote share path. diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index 77b481e48e..31f1eb8097 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -36,7 +36,9 @@ def provision_host(self, params: dict) -> Tuple["Status", dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ - def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Status", dict]: + def wait_host_deployment( + self, params: dict, *, is_setup: bool + ) -> Tuple["Status", dict]: """ Waits for a pending operation on a Host/VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index c4c5f01ddc..126966c713 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -32,9 +32,12 @@ class SupportsLocalExec(Protocol): Used in LocalEnv and provided by LocalExecService. """ - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, TunableValue]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, TunableValue]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. @@ -55,7 +58,9 @@ def local_exec(self, script_lines: Iterable[str], A 3-tuple of return code, stdout, and stderr of the script process. """ - def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: + def temp_dir_context( + self, path: Optional[str] = None + ) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index fb753aa21c..5ce5ebb8e4 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -36,7 +36,9 @@ def provision_network(self, params: dict) -> Tuple["Status", dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ - def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Status", dict]: + def wait_network_deployment( + self, params: dict, *, is_setup: bool + ) -> Tuple["Status", dict]: """ Waits for a pending operation on a Network to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. @@ -56,7 +58,9 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta Result is info on the operation runtime if SUCCEEDED, otherwise {}. 
""" - def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]: + def deprovision_network( + self, params: dict, ignore_errors: bool = True + ) -> Tuple["Status", dict]: """ Deprovisions the Network by deleting it. diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index c653e10c2b..8a414fad8e 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -18,8 +18,9 @@ class SupportsRemoteConfig(Protocol): Protocol interface for configuring cloud services. """ - def configure(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple["Status", dict]: + def configure( + self, config: Dict[str, Any], params: Dict[str, Any] + ) -> Tuple["Status", dict]: """ Update the parameters of a SaaS service in the cloud. diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index 096cb3c675..f6ca57912a 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -20,8 +20,9 @@ class SupportsRemoteExec(Protocol): scripts on a remote host OS. """ - def remote_exec(self, script: Iterable[str], config: dict, - env_params: dict) -> Tuple["Status", dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple["Status", dict]: """ Run a command on remote host OS. diff --git a/mlos_bench/mlos_bench/storage/__init__.py b/mlos_bench/mlos_bench/storage/__init__.py index 9ae5c80f36..0812270747 100644 --- a/mlos_bench/mlos_bench/storage/__init__.py +++ b/mlos_bench/mlos_bench/storage/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.storage.storage_factory import from_config __all__ = [ - 'Storage', - 'from_config', + "Storage", + "from_config", ] diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index ce07e44e2b..47581f0725 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -32,12 +32,15 @@ class ExperimentData(metaclass=ABCMeta): RESULT_COLUMN_PREFIX = "result." CONFIG_COLUMN_PREFIX = "config." - def __init__(self, *, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str): + def __init__( + self, + *, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str, + ): self._experiment_id = experiment_id self._description = description self._root_env_config = root_env_config @@ -142,9 +145,9 @@ def default_tunable_config_id(self) -> Optional[int]: trials_items = sorted(self.trials.items()) if not trials_items: return None - for (_trial_id, trial) in trials_items: + for _trial_id, trial in trials_items: # Take the first config id marked as "defaults" when it was instantiated. 
- if strtobool(str(trial.metadata_dict.get('is_defaults', False))): + if strtobool(str(trial.metadata_dict.get("is_defaults", False))): return trial.tunable_config_id # Fallback (min trial_id) return trials_items[0][1].tunable_config_id diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 2165fa706f..8167504627 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -30,10 +30,12 @@ class Storage(metaclass=ABCMeta): and storage systems (e.g., SQLite or MLFLow). """ - def __init__(self, - config: Dict[str, Any], - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + config: Dict[str, Any], + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): """ Create a new storage object. @@ -74,13 +76,16 @@ def experiments(self) -> Dict[str, ExperimentData]: """ @abstractmethod - def experiment(self, *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment': + def experiment( + self, + *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal["min", "max"]], + ) -> "Storage.Experiment": """ Create a new experiment in the storage. @@ -116,23 +121,27 @@ class Experiment(metaclass=ABCMeta): This class is instantiated in the `Storage.experiment()` method. """ - def __init__(self, - *, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal['min', 'max']]): + def __init__( + self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal["min", "max"]], + ): self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config) + (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( + root_env_config + ) self._description = description self._opt_targets = opt_targets self._in_context = False - def __enter__(self) -> 'Storage.Experiment': + def __enter__(self) -> "Storage.Experiment": """ Enter the context of the experiment. @@ -144,9 +153,12 @@ def __enter__(self) -> 'Storage.Experiment': self._in_context = True return self - def __exit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: """ End the context of the experiment. 
@@ -157,8 +169,11 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], _LOG.debug("Finishing experiment: %s", self) else: assert exc_type and exc_val - _LOG.warning("Finishing experiment: %s", self, - exc_info=(exc_type, exc_val, exc_tb)) + _LOG.warning( + "Finishing experiment: %s", + self, + exc_info=(exc_type, exc_val, exc_tb), + ) assert self._in_context self._teardown(is_ok) self._in_context = False @@ -248,8 +263,10 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: """ @abstractmethod - def load(self, last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load( + self, + last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: """ Load (tunable values, benchmark scores, status) to warm-up the optimizer. @@ -269,7 +286,9 @@ def load(self, last_trial_id: int = -1, """ @abstractmethod - def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']: + def pending_trials( + self, timestamp: datetime, *, running: bool + ) -> Iterator["Storage.Trial"]: """ Return an iterator over the pending trials that are scheduled to run on or before the specified timestamp. @@ -289,8 +308,12 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Sto """ @abstractmethod - def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial': + def new_trial( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> "Storage.Trial": """ Create a new experiment run in the storage. @@ -317,10 +340,16 @@ class Trial(metaclass=ABCMeta): This class is instantiated in the `Storage.Experiment.trial()` method. """ - def __init__(self, *, - tunables: TunableGroups, experiment_id: str, trial_id: int, - tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']], - config: Optional[Dict[str, Any]] = None): + def __init__( + self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + opt_targets: Dict[str, Literal["min", "max"]], + config: Optional[Dict[str, Any]] = None, + ): self._tunables = tunables self._experiment_id = experiment_id self._trial_id = trial_id @@ -361,7 +390,9 @@ def tunables(self) -> TunableGroups: """ return self._tunables - def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def config( + self, global_config: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """ Produce a copy of the global configuration updated with the parameters of the current trial. @@ -378,9 +409,12 @@ def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, An return config @abstractmethod - def update(self, status: Status, timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update( + self, + status: Status, + timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: """ Update the storage with the results of the experiment. 
@@ -404,14 +438,21 @@ def update(self, status: Status, timestamp: datetime, assert metrics is not None opt_targets = set(self._opt_targets.keys()) if not opt_targets.issubset(metrics.keys()): - _LOG.warning("Trial %s :: opt.targets missing: %s", - self, opt_targets.difference(metrics.keys())) + _LOG.warning( + "Trial %s :: opt.targets missing: %s", + self, + opt_targets.difference(metrics.keys()), + ) # raise ValueError() return metrics @abstractmethod - def update_telemetry(self, status: Status, timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]]) -> None: + def update_telemetry( + self, + status: Status, + timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]], + ) -> None: """ Save the experiment's telemetry data and intermediate status. @@ -424,4 +465,6 @@ def update_telemetry(self, status: Status, timestamp: datetime, metrics : List[Tuple[datetime, str, Any]] Telemetry data. """ - _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics)) + _LOG.info( + "Store telemetry: %s :: %s %d records", self, status, len(metrics) + ) diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index b3b2bed86a..cc4eebf9df 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -31,18 +31,23 @@ class TrialData(metaclass=ABCMeta): of tunable parameters). """ - def __init__(self, *, - experiment_id: str, - trial_id: int, - tunable_config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status): + def __init__( + self, + *, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status, + ): self._experiment_id = experiment_id self._trial_id = trial_id self._tunable_config_id = tunable_config_id assert ts_start.tzinfo == UTC, "ts_start must be in UTC" - assert ts_end is None or ts_end.tzinfo == UTC, "ts_end must be in UTC if not None" + assert ( + ts_end is None or ts_end.tzinfo == UTC + ), "ts_end must be in UTC if not None" self._ts_start = ts_start self._ts_end = ts_end self._status = status @@ -53,7 +58,10 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return self._experiment_id == other._experiment_id and self._trial_id == other._trial_id + return ( + self._experiment_id == other._experiment_id + and self._trial_id == other._trial_id + ) @property def experiment_id(self) -> str: diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py index 0dce110b1b..0c9adce22d 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py @@ -21,8 +21,7 @@ class TunableConfigData(metaclass=ABCMeta): A configuration in this context is the set of tunable parameter values. 
""" - def __init__(self, *, - tunable_config_id: int): + def __init__(self, *, tunable_config_id: int): self._tunable_config_id = tunable_config_id def __repr__(self) -> str: diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py index 18c50035a9..6cabaaf3ba 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py @@ -27,14 +27,19 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta): (e.g., for repeats), which we call a (tunable) config trial group. """ - def __init__(self, *, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None): + def __init__( + self, + *, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None, + ): self._experiment_id = experiment_id self._tunable_config_id = tunable_config_id # can be lazily initialized as necessary: - self._tunable_config_trial_group_id: Optional[int] = tunable_config_trial_group_id + self._tunable_config_trial_group_id: Optional[int] = ( + tunable_config_trial_group_id + ) @property def experiment_id(self) -> str: @@ -67,7 +72,9 @@ def tunable_config_trial_group_id(self) -> int: config_id. """ if self._tunable_config_trial_group_id is None: - self._tunable_config_trial_group_id = self._get_tunable_config_trial_group_id() + self._tunable_config_trial_group_id = ( + self._get_tunable_config_trial_group_id() + ) assert self._tunable_config_trial_group_id is not None return self._tunable_config_trial_group_id @@ -77,7 +84,10 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id + return ( + self._tunable_config_id == other._tunable_config_id + and self._experiment_id == other._experiment_id + ) @property @abstractmethod diff --git a/mlos_bench/mlos_bench/storage/sql/__init__.py b/mlos_bench/mlos_bench/storage/sql/__init__.py index 735e21bcaf..cf09b9aa5a 100644 --- a/mlos_bench/mlos_bench/storage/sql/__init__.py +++ b/mlos_bench/mlos_bench/storage/sql/__init__.py @@ -8,5 +8,5 @@ from mlos_bench.storage.sql.storage import SqlStorage __all__ = [ - 'SqlStorage', + "SqlStorage", ] diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index c7ee73a3bc..50d944150b 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -18,10 +18,11 @@ def get_trials( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]: + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: Optional[int] = None, +) -> Dict[int, TrialData]: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. @@ -30,13 +31,18 @@ def get_trials( from mlos_bench.storage.sql.trial_data import ( TrialSqlData, # pylint: disable=import-outside-toplevel,cyclic-import ) + with engine.connect() as conn: # Build up sql a statement for fetching trials. 
- stmt = schema.trial.select().where( - schema.trial.c.exp_id == experiment_id, - ).order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), + stmt = ( + schema.trial.select() + .where( + schema.trial.c.exp_id == experiment_id, + ) + .order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -60,10 +66,11 @@ def get_trials( def get_results_df( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None) -> pandas.DataFrame: + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: Optional[int] = None, +) -> pandas.DataFrame: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. @@ -72,15 +79,22 @@ def get_results_df( # pylint: disable=too-many-locals with engine.connect() as conn: # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config. - tunable_config_group_id_stmt = schema.trial.select().with_only_columns( - schema.trial.c.exp_id, - schema.trial.c.config_id, - func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'), - ).where( - schema.trial.c.exp_id == experiment_id, - ).group_by( - schema.trial.c.exp_id, - schema.trial.c.config_id, + tunable_config_group_id_stmt = ( + schema.trial.select() + .with_only_columns( + schema.trial.c.exp_id, + schema.trial.c.config_id, + func.min(schema.trial.c.trial_id) + .cast(Integer) + .label("tunable_config_trial_group_id"), + ) + .where( + schema.trial.c.exp_id == experiment_id, + ) + .group_by( + schema.trial.c.exp_id, + schema.trial.c.config_id, + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -90,18 +104,24 @@ def get_results_df( tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery() # Get each trial's metadata. - cur_trials_stmt = select( - schema.trial, - tunable_config_trial_group_id_subquery, - ).where( - schema.trial.c.exp_id == experiment_id, - and_( - tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, - tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, - ), - ).order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), + cur_trials_stmt = ( + select( + schema.trial, + tunable_config_trial_group_id_subquery, + ) + .where( + schema.trial.c.exp_id == experiment_id, + and_( + tunable_config_trial_group_id_subquery.c.exp_id + == schema.trial.c.exp_id, + tunable_config_trial_group_id_subquery.c.config_id + == schema.trial.c.config_id, + ), + ) + .order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), + ) ) # Optionally restrict to those using a particular tunable config. 
if tunable_config_id is not None: @@ -110,39 +130,48 @@ def get_results_df( ) cur_trials = conn.execute(cur_trials_stmt) trials_df = pandas.DataFrame( - [( - row.trial_id, - utcify_timestamp(row.ts_start, origin="utc"), - utcify_nullable_timestamp(row.ts_end, origin="utc"), - row.config_id, - row.tunable_config_trial_group_id, - row.status, - ) for row in cur_trials.fetchall()], + [ + ( + row.trial_id, + utcify_timestamp(row.ts_start, origin="utc"), + utcify_nullable_timestamp(row.ts_end, origin="utc"), + row.config_id, + row.tunable_config_trial_group_id, + row.status, + ) + for row in cur_trials.fetchall() + ], columns=[ - 'trial_id', - 'ts_start', - 'ts_end', - 'tunable_config_id', - 'tunable_config_trial_group_id', - 'status', - ] + "trial_id", + "ts_start", + "ts_end", + "tunable_config_id", + "tunable_config_trial_group_id", + "status", + ], ) # Get each trial's config in wide format. - configs_stmt = schema.trial.select().with_only_columns( - schema.trial.c.trial_id, - schema.trial.c.config_id, - schema.config_param.c.param_id, - schema.config_param.c.param_value, - ).where( - schema.trial.c.exp_id == experiment_id, - ).join( - schema.config_param, - schema.config_param.c.config_id == schema.trial.c.config_id, - isouter=True - ).order_by( - schema.trial.c.trial_id, - schema.config_param.c.param_id, + configs_stmt = ( + schema.trial.select() + .with_only_columns( + schema.trial.c.trial_id, + schema.trial.c.config_id, + schema.config_param.c.param_id, + schema.config_param.c.param_value, + ) + .where( + schema.trial.c.exp_id == experiment_id, + ) + .join( + schema.config_param, + schema.config_param.c.config_id == schema.trial.c.config_id, + isouter=True, + ) + .order_by( + schema.trial.c.trial_id, + schema.config_param.c.param_id, + ) ) if tunable_config_id is not None: configs_stmt = configs_stmt.where( @@ -150,41 +179,67 @@ def get_results_df( ) configs = conn.execute(configs_stmt) configs_df = pandas.DataFrame( - [(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value) - for row in configs.fetchall()], - columns=['trial_id', 'tunable_config_id', 'param', 'value'] + [ + ( + row.trial_id, + row.config_id, + ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, + row.param_value, + ) + for row in configs.fetchall() + ], + columns=["trial_id", "tunable_config_id", "param", "value"], ).pivot( - index=["trial_id", "tunable_config_id"], columns="param", values="value", + index=["trial_id", "tunable_config_id"], + columns="param", + values="value", ) - configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp) + configs_df = configs_df.apply(pandas.to_numeric, errors="coerce").fillna(configs_df) # type: ignore[assignment] # (fp) # Get each trial's results in wide format. 
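The configs_df and results_df construction in this function follows a common pandas recipe: collect long-format (id, key, value) rows, pivot them into one wide row per id, then coerce numeric-looking strings with to_numeric while falling back to the original values. A self-contained sketch of that recipe on toy data (not the mlos_bench schema):

import pandas

# Toy long-format rows: (trial_id, metric, value), with values stored as strings.
rows = [
    (1, "score", "0.95"),
    (1, "latency_ms", "12"),
    (2, "score", "0.90"),
    (2, "color", "blue"),  # non-numeric values should survive the conversion
]
long_df = pandas.DataFrame(rows, columns=["trial_id", "metric", "value"])

# Pivot to one row per trial_id and one column per metric.
wide_df = long_df.pivot(index="trial_id", columns="metric", values="value")

# Convert numeric-looking strings to numbers; where conversion fails,
# errors="coerce" yields NaN and fillna() restores the original value.
wide_df = wide_df.apply(pandas.to_numeric, errors="coerce").fillna(wide_df)
print(wide_df)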
- results_stmt = schema.trial_result.select().with_only_columns( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, - schema.trial_result.c.metric_value, - ).where( - schema.trial_result.c.exp_id == experiment_id, - ).order_by( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, + results_stmt = ( + schema.trial_result.select() + .with_only_columns( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + schema.trial_result.c.metric_value, + ) + .where( + schema.trial_result.c.exp_id == experiment_id, + ) + .order_by( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + ) ) if tunable_config_id is not None: - results_stmt = results_stmt.join(schema.trial, and_( - schema.trial.c.exp_id == schema.trial_result.c.exp_id, - schema.trial.c.trial_id == schema.trial_result.c.trial_id, - schema.trial.c.config_id == tunable_config_id, - )) + results_stmt = results_stmt.join( + schema.trial, + and_( + schema.trial.c.exp_id == schema.trial_result.c.exp_id, + schema.trial.c.trial_id == schema.trial_result.c.trial_id, + schema.trial.c.config_id == tunable_config_id, + ), + ) results = conn.execute(results_stmt) results_df = pandas.DataFrame( - [(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value) - for row in results.fetchall()], - columns=['trial_id', 'metric', 'value'] + [ + ( + row.trial_id, + ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, + row.metric_value, + ) + for row in results.fetchall() + ], + columns=["trial_id", "metric", "value"], ).pivot( - index="trial_id", columns="metric", values="value", + index="trial_id", + columns="metric", + values="value", ) - results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp) + results_df = results_df.apply(pandas.to_numeric, errors="coerce").fillna(results_df) # type: ignore[assignment] # (fp) # Concat the trials, configs, and results. - return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \ - .merge(results_df, on="trial_id", how="left") + return trials_df.merge( + configs_df, on=["trial_id", "tunable_config_id"], how="left" + ).merge(results_df, on="trial_id", how="left") diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 58ee3dddb5..e231188f71 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -29,15 +29,18 @@ class Experiment(Storage.Experiment): Logic for retrieving and storing the results of a single experiment. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal['min', 'max']]): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal["min", "max"]], + ): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -55,18 +58,22 @@ def _setup(self) -> None: # Get git info and the last trial ID for the experiment. 
# pylint: disable=not-callable exp_info = conn.execute( - self._schema.experiment.select().with_only_columns( + self._schema.experiment.select() + .with_only_columns( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, func.max(self._schema.trial.c.trial_id).label("trial_id"), - ).join( + ) + .join( self._schema.trial, self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id, - isouter=True - ).where( + isouter=True, + ) + .where( self._schema.experiment.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, @@ -75,33 +82,47 @@ def _setup(self) -> None: if exp_info is None: _LOG.info("Start new experiment: %s", self._experiment_id) # It's a new experiment: create a record for it in the database. - conn.execute(self._schema.experiment.insert().values( - exp_id=self._experiment_id, - description=self._description, - git_repo=self._git_repo, - git_commit=self._git_commit, - root_env_config=self._root_env_config, - )) - conn.execute(self._schema.objectives.insert().values([ - { - "exp_id": self._experiment_id, - "optimization_target": opt_target, - "optimization_direction": opt_dir, - } - for (opt_target, opt_dir) in self.opt_targets.items() - ])) + conn.execute( + self._schema.experiment.insert().values( + exp_id=self._experiment_id, + description=self._description, + git_repo=self._git_repo, + git_commit=self._git_commit, + root_env_config=self._root_env_config, + ) + ) + conn.execute( + self._schema.objectives.insert().values( + [ + { + "exp_id": self._experiment_id, + "optimization_target": opt_target, + "optimization_direction": opt_dir, + } + for (opt_target, opt_dir) in self.opt_targets.items() + ] + ) + ) else: if exp_info.trial_id is not None: self._trial_id = exp_info.trial_id + 1 - _LOG.info("Continue experiment: %s last trial: %s resume from: %d", - self._experiment_id, exp_info.trial_id, self._trial_id) + _LOG.info( + "Continue experiment: %s last trial: %s resume from: %d", + self._experiment_id, + exp_info.trial_id, + self._trial_id, + ) # TODO: Sanity check that certain critical configs (e.g., # objectives) haven't changed to be incompatible such that a new # experiment should be started (possibly by prewarming with the # previous one). 
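Although the hunks above only change formatting, the _setup() logic they touch reduces to: look the experiment up together with its highest trial id, create the experiment record if it does not exist, otherwise resume numbering after the last recorded trial. A rough, storage-agnostic sketch of that flow using only sqlite3 (a simplified two-table schema, not the real mlos_bench one):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE experiment (exp_id TEXT PRIMARY KEY, description TEXT)")
conn.execute("CREATE TABLE trial (exp_id TEXT, trial_id INTEGER)")


def setup_experiment(conn: sqlite3.Connection, exp_id: str, description: str) -> int:
    """Return the next trial id, creating the experiment record if needed."""
    row = conn.execute(
        "SELECT e.exp_id, MAX(t.trial_id) "
        "FROM experiment e LEFT JOIN trial t ON t.exp_id = e.exp_id "
        "WHERE e.exp_id = ? GROUP BY e.exp_id",
        (exp_id,),
    ).fetchone()
    if row is None:
        # New experiment: insert its record and start counting from trial 1.
        conn.execute("INSERT INTO experiment VALUES (?, ?)", (exp_id, description))
        return 1
    # Existing experiment: continue after the last recorded trial, if any.
    last_trial_id = row[1]
    return (last_trial_id + 1) if last_trial_id is not None else 1


print(setup_experiment(conn, "exp-1", "demo"))  # 1 (new experiment)
conn.execute("INSERT INTO trial VALUES ('exp-1', 5)")
print(setup_experiment(conn, "exp-1", "demo"))  # 6 (resume after trial 5)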
if exp_info.git_commit != self._git_commit: - _LOG.warning("Experiment %s git expected: %s %s", - self, exp_info.git_repo, exp_info.git_commit) + _LOG.warning( + "Experiment %s git expected: %s %s", + self, + exp_info.git_repo, + exp_info.git_commit, + ) def merge(self, experiment_ids: List[str]) -> None: _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids) @@ -109,38 +130,55 @@ def merge(self, experiment_ids: List[str]) -> None: def load_tunable_config(self, config_id: int) -> Dict[str, Any]: with self._engine.connect() as conn: - return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id) + return self._get_key_val( + conn, self._schema.config_param, "param", config_id=config_id + ) def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select().where( + self._schema.trial_telemetry.select() + .where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == trial_id - ).order_by( + self._schema.trial_telemetry.c.trial_id == trial_id, + ) + .order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) ) # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. - return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) - for row in cur_telemetry.fetchall()] + return [ + ( + utcify_timestamp(row.ts, origin="utc"), + row.metric_id, + row.metric_value, + ) + for row in cur_telemetry.fetchall() + ] - def load(self, last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load( + self, + last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: with self._engine.connect() as conn: cur_trials = conn.execute( - self._schema.trial.select().with_only_columns( + self._schema.trial.select() + .with_only_columns( self._schema.trial.c.trial_id, self._schema.trial.c.config_id, self._schema.trial.c.status, - ).where( + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id > last_trial_id, - self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']), - ).order_by( + self._schema.trial.c.status.in_( + ["SUCCEEDED", "FAILED", "TIMED_OUT"] + ), + ) + .order_by( self._schema.trial.c.trial_id.asc(), ) ) @@ -154,19 +192,33 @@ def load(self, last_trial_id: int = -1, stat = Status[trial.status] status.append(stat) trial_ids.append(trial.trial_id) - configs.append(self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id)) + configs.append( + self._get_key_val( + conn, + self._schema.config_param, + "param", + config_id=trial.config_id, + ) + ) if stat.is_succeeded(): - scores.append(self._get_key_val( - conn, self._schema.trial_result, "metric", - exp_id=self._experiment_id, trial_id=trial.trial_id)) + scores.append( + self._get_key_val( + conn, + self._schema.trial_result, + "metric", + exp_id=self._experiment_id, + trial_id=trial.trial_id, + ) + ) else: scores.append(None) return (trial_ids, configs, scores, status) @staticmethod - def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]: + def _get_key_val( + conn: Connection, table: Table, field: str, **kwargs: Any + ) -> Dict[str, Any]: """ Helper method to retrieve key-value pairs from the database. 
(E.g., configurations, results, and telemetry). @@ -175,49 +227,63 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D select( column(f"{field}_id"), column(f"{field}_value"), - ).select_from(table).where( - *[column(key) == val for (key, val) in kwargs.items()] ) + .select_from(table) + .where(*[column(key) == val for (key, val) in kwargs.items()]) ) # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts. - return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access + return dict( + row._tuple() for row in cur_result.fetchall() + ) # pylint: disable=protected-access @staticmethod - def _save_params(conn: Connection, table: Table, - params: Dict[str, Any], **kwargs: Any) -> None: + def _save_params( + conn: Connection, table: Table, params: Dict[str, Any], **kwargs: Any + ) -> None: if not params: return - conn.execute(table.insert(), [ - { - **kwargs, - "param_id": key, - "param_value": nullable(str, val) - } - for (key, val) in params.items() - ]) + conn.execute( + table.insert(), + [ + {**kwargs, "param_id": key, "param_value": nullable(str, val)} + for (key, val) in params.items() + ], + ) - def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]: + def pending_trials( + self, timestamp: datetime, *, running: bool + ) -> Iterator[Storage.Trial]: timestamp = utcify_timestamp(timestamp, origin="local") - _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp) + _LOG.info( + "Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp + ) if running: - pending_status = ['PENDING', 'READY', 'RUNNING'] + pending_status = ["PENDING", "READY", "RUNNING"] else: - pending_status = ['PENDING'] + pending_status = ["PENDING"] with self._engine.connect() as conn: - cur_trials = conn.execute(self._schema.trial.select().where( - self._schema.trial.c.exp_id == self._experiment_id, - (self._schema.trial.c.ts_start.is_(None) | - (self._schema.trial.c.ts_start <= timestamp)), - self._schema.trial.c.ts_end.is_(None), - self._schema.trial.c.status.in_(pending_status), - )) + cur_trials = conn.execute( + self._schema.trial.select().where( + self._schema.trial.c.exp_id == self._experiment_id, + ( + self._schema.trial.c.ts_start.is_(None) + | (self._schema.trial.c.ts_start <= timestamp) + ), + self._schema.trial.c.ts_end.is_(None), + self._schema.trial.c.status.in_(pending_status), + ) + ) for trial in cur_trials.fetchall(): tunables = self._get_key_val( - conn, self._schema.config_param, "param", - config_id=trial.config_id) + conn, self._schema.config_param, "param", config_id=trial.config_id + ) config = self._get_key_val( - conn, self._schema.trial_param, "param", - exp_id=self._experiment_id, trial_id=trial.trial_id) + conn, + self._schema.trial_param, + "param", + exp_id=self._experiment_id, + trial_id=trial.trial_id, + ) yield Trial( engine=self._engine, schema=self._schema, @@ -235,42 +301,59 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: Get the config ID for the given tunables. If the config does not exist, create a new record for it. 
""" - config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest() - cur_config = conn.execute(self._schema.config.select().where( - self._schema.config.c.config_hash == config_hash - )).fetchone() + config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest() + cur_config = conn.execute( + self._schema.config.select().where( + self._schema.config.c.config_hash == config_hash + ) + ).fetchone() if cur_config is not None: return int(cur_config.config_id) # mypy doesn't know it's always int # Config not found, create a new one: - config_id: int = conn.execute(self._schema.config.insert().values( - config_hash=config_hash)).inserted_primary_key[0] + config_id: int = conn.execute( + self._schema.config.insert().values(config_hash=config_hash) + ).inserted_primary_key[0] self._save_params( - conn, self._schema.config_param, + conn, + self._schema.config_param, {tunable.name: tunable.value for (tunable, _group) in tunables}, - config_id=config_id) + config_id=config_id, + ) return config_id - def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> Storage.Trial: + def new_trial( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Storage.Trial: ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local") - _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start) + _LOG.debug( + "Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start + ) with self._engine.begin() as conn: try: config_id = self._get_config_id(conn, tunables) - conn.execute(self._schema.trial.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - config_id=config_id, - ts_start=ts_start, - status='PENDING', - )) + conn.execute( + self._schema.trial.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + config_id=config_id, + ts_start=ts_start, + status="PENDING", + ) + ) # Note: config here is the framework config, not the target # environment config (i.e., tunables). if config is not None: self._save_params( - conn, self._schema.trial_param, config, - exp_id=self._experiment_id, trial_id=self._trial_id) + conn, + self._schema.trial_param, + config, + exp_id=self._experiment_id, + trial_id=self._trial_id, + ) trial = Trial( engine=self._engine, diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index eaa6e1041f..a370ad1060 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -35,14 +35,17 @@ class ExperimentSqlData(ExperimentData): scripts and mlos_bench configuration files. 
""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str, + ): super().__init__( experiment_id=experiment_id, description=description, @@ -57,9 +60,11 @@ def __init__(self, *, def objectives(self) -> Dict[str, Literal["min", "max"]]: with self._engine.connect() as conn: objectives_db_data = conn.execute( - self._schema.objectives.select().where( + self._schema.objectives.select() + .where( self._schema.objectives.c.exp_id == self._experiment_id, - ).order_by( + ) + .order_by( self._schema.objectives.c.weight.desc(), self._schema.objectives.c.optimization_target.asc(), ) @@ -80,13 +85,19 @@ def trials(self) -> Dict[int, TrialData]: def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: with self._engine.connect() as conn: tunable_config_trial_groups = conn.execute( - self._schema.trial.select().with_only_columns( + self._schema.trial.select() + .with_only_columns( self._schema.trial.c.config_id, - func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - 'tunable_config_trial_group_id'), - ).where( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label( # pylint: disable=not-callable + "tunable_config_trial_group_id" + ), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -106,11 +117,14 @@ def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: def tunable_configs(self) -> Dict[int, TunableConfigData]: with self._engine.connect() as conn: tunable_configs = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -139,20 +153,30 @@ def default_tunable_config_id(self) -> Optional[int]: """ with self._engine.connect() as conn: query_results = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial_param.select().with_only_columns( - func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - "first_trial_id_with_defaults"), - ).where( + self._schema.trial_param.select() + .with_only_columns( + func.min(self._schema.trial_param.c.trial_id) + .cast(Integer) + .label( # pylint: disable=not-callable + "first_trial_id_with_defaults" + ), + ) + .where( self._schema.trial_param.c.exp_id == self._experiment_id, self._schema.trial_param.c.param_id == "is_defaults", - func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]), - ).scalar_subquery() - ) + func.lower( + self._schema.trial_param.c.param_value, type_=String + ).in_(["1", "true"]), + ) + .scalar_subquery() + ), ) ) 
min_default_trial_row = query_results.fetchone() @@ -161,17 +185,24 @@ def default_tunable_config_id(self) -> Optional[int]: return min_default_trial_row._tuple()[0] # fallback logic - assume minimum trial_id for experiment query_results = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial.select().with_only_columns( - func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"), - ).where( + self._schema.trial.select() + .with_only_columns( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label("first_trial_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).scalar_subquery() - ) + ) + .scalar_subquery() + ), ) ) min_trial_row = query_results.fetchone() diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 9a1eca2744..abc5ab27ac 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -80,7 +80,6 @@ def __init__(self, engine: Engine): Column("root_env_config", String(1024), nullable=False), Column("git_repo", String(1024), nullable=False), Column("git_commit", String(40), nullable=False), - PrimaryKeyConstraint("exp_id"), ) @@ -95,20 +94,29 @@ def __init__(self, engine: Engine): # Will need to adjust the insert and return values to support this # eventually. Column("weight", Float, nullable=True), - PrimaryKeyConstraint("exp_id", "optimization_target"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ) # A workaround for SQLAlchemy issue with autoincrement in DuckDB: if engine.dialect.name == "duckdb": - seq_config_id = Sequence('seq_config_id') - col_config_id = Column("config_id", Integer, seq_config_id, - server_default=seq_config_id.next_value(), - nullable=False, primary_key=True) + seq_config_id = Sequence("seq_config_id") + col_config_id = Column( + "config_id", + Integer, + seq_config_id, + server_default=seq_config_id.next_value(), + nullable=False, + primary_key=True, + ) else: - col_config_id = Column("config_id", Integer, nullable=False, - primary_key=True, autoincrement=True) + col_config_id = Column( + "config_id", + Integer, + nullable=False, + primary_key=True, + autoincrement=True, + ) self.config = Table( "config", @@ -127,7 +135,6 @@ def __init__(self, engine: Engine): Column("ts_end", DateTime), # Should match the text IDs of `mlos_bench.environments.Status` enum: Column("status", String(self._STATUS_LEN), nullable=False), - PrimaryKeyConstraint("exp_id", "trial_id"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), @@ -141,7 +148,6 @@ def __init__(self, engine: Engine): Column("config_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), - PrimaryKeyConstraint("config_id", "param_id"), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), ) @@ -155,10 +161,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), - PrimaryKeyConstraint("exp_id", "trial_id", 
"param_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_status = Table( @@ -168,10 +174,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("status", String(self._STATUS_LEN), nullable=False), - UniqueConstraint("exp_id", "trial_id", "ts"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_result = Table( @@ -181,10 +187,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), - PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_telemetry = Table( @@ -195,15 +201,15 @@ def __init__(self, engine: Engine): Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), - UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) _LOG.debug("Schema: %s", self._meta) - def create(self) -> 'DbSchema': + def create(self) -> "DbSchema": """ Create the DB schema. """ diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index bde38575bd..a52861d3ad 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -27,10 +27,12 @@ class SqlStorage(Storage): An implementation of the Storage interface using SQLAlchemy backend. 
""" - def __init__(self, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(config, global_config, service) lazy_schema_create = self._config.pop("lazy_schema_create", False) self._log_sql = self._config.pop("log_sql", False) @@ -47,7 +49,7 @@ def __init__(self, @property def _schema(self) -> DbSchema: """Lazily create schema upon first access.""" - if not hasattr(self, '_db_schema'): + if not hasattr(self, "_db_schema"): self._db_schema = DbSchema(self._engine).create() if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("DDL statements:\n%s", self._schema) @@ -56,13 +58,16 @@ def _schema(self) -> DbSchema: def __repr__(self) -> str: return self._repr - def experiment(self, *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment: + def experiment( + self, + *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal["min", "max"]], + ) -> Storage.Experiment: return Experiment( engine=self._engine, schema=self._schema, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 7ac7958845..d730aef0aa 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -27,15 +27,18 @@ class Trial(Storage.Trial): Store the results of a single run of the experiment in SQL database. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - config_id: int, - opt_targets: Dict[str, Literal['min', 'max']], - config: Optional[Dict[str, Any]] = None): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + config_id: int, + opt_targets: Dict[str, Literal["min", "max"]], + config: Optional[Dict[str, Any]] = None, + ): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -47,9 +50,12 @@ def __init__(self, *, self._engine = engine self._schema = schema - def update(self, status: Status, timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update( + self, + status: Status, + timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: # Make sure to convert the timestamp to UTC before storing it in the database. 
timestamp = utcify_timestamp(timestamp, origin="local") metrics = super().update(status, timestamp, metrics) @@ -59,13 +65,16 @@ def update(self, status: Status, timestamp: datetime, if status.is_completed(): # Final update of the status and ts_end: cur_status = conn.execute( - self._schema.trial.update().where( + self._schema.trial.update() + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), - ).values( + ["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] + ), + ) + .values( status=status.name, ts_end=timestamp, ) @@ -73,67 +82,96 @@ def update(self, status: Status, timestamp: datetime, if cur_status.rowcount not in {1, -1}: _LOG.warning("Trial %s :: update failed: %s", self, status) raise RuntimeError( - f"Failed to update the status of the trial {self} to {status}." + - f" ({cur_status.rowcount} rows)") + f"Failed to update the status of the trial {self} to {status}." + + f" ({cur_status.rowcount} rows)" + ) if metrics: - conn.execute(self._schema.trial_result.insert().values([ - { - "exp_id": self._experiment_id, - "trial_id": self._trial_id, - "metric_id": key, - "metric_value": nullable(str, val), - } - for (key, val) in metrics.items() - ])) + conn.execute( + self._schema.trial_result.insert().values( + [ + { + "exp_id": self._experiment_id, + "trial_id": self._trial_id, + "metric_id": key, + "metric_value": nullable(str, val), + } + for (key, val) in metrics.items() + ] + ) + ) else: # Update of the status and ts_start when starting the trial: assert metrics is None, f"Unexpected metrics for status: {status}" cur_status = conn.execute( - self._schema.trial.update().where( + self._schema.trial.update() + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), - ).values( + [ + "RUNNING", + "SUCCEEDED", + "CANCELED", + "FAILED", + "TIMED_OUT", + ] + ), + ) + .values( status=status.name, ts_start=timestamp, ) ) if cur_status.rowcount not in {1, -1}: # Keep the old status and timestamp if already running, but log it. - _LOG.warning("Trial %s :: cannot be updated to: %s", self, status) + _LOG.warning( + "Trial %s :: cannot be updated to: %s", self, status + ) except Exception: conn.rollback() raise return metrics - def update_telemetry(self, status: Status, timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]]) -> None: + def update_telemetry( + self, + status: Status, + timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]], + ) -> None: super().update_telemetry(status, timestamp, metrics) # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") - metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics] + metrics = [ + (utcify_timestamp(ts, origin="local"), key, val) + for (ts, key, val) in metrics + ] # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()` # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of # a bulk upsert. 
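The telemetry loop below inserts one row per metric and treats a duplicate-key IntegrityError as "already recorded", which is what keeps update_telemetry() idempotent without relying on dialect-specific upserts. The same pattern in miniature against an in-memory sqlite3 table with a UNIQUE constraint (purely illustrative, not the mlos_bench schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE telemetry ("
    " trial_id INTEGER, ts TEXT, metric_id TEXT, metric_value TEXT,"
    " UNIQUE (trial_id, ts, metric_id))"
)


def record(trial_id: int, ts: str, metric_id: str, value: str) -> None:
    """Insert one telemetry row; silently skip rows that already exist."""
    try:
        with conn:  # one transaction per row, mirroring the loop below
            conn.execute(
                "INSERT INTO telemetry VALUES (?, ?, ?, ?)",
                (trial_id, ts, metric_id, value),
            )
    except sqlite3.IntegrityError:
        pass  # duplicate (trial_id, ts, metric_id): already recorded


record(1, "2024-07-03T21:05:00Z", "throughput", "42")
record(1, "2024-07-03T21:05:00Z", "throughput", "42")  # no error the second time
print(conn.execute("SELECT COUNT(*) FROM telemetry").fetchone()[0])  # 1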
# See Also: comments in with self._engine.begin() as conn: self._update_status(conn, status, timestamp) - for (metric_ts, key, val) in metrics: + for metric_ts, key, val in metrics: with self._engine.begin() as conn: try: - conn.execute(self._schema.trial_telemetry.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=metric_ts, - metric_id=key, - metric_value=nullable(str, val), - )) + conn.execute( + self._schema.trial_telemetry.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=metric_ts, + metric_id=key, + metric_value=nullable(str, val), + ) + ) except IntegrityError as ex: - _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex) + _LOG.warning( + "Record already exists: %s :: %s", (metric_ts, key, val), ex + ) - def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None: + def _update_status( + self, conn: Connection, status: Status, timestamp: datetime + ) -> None: """ Insert a new status record into the database. This call is idempotent. @@ -141,12 +179,18 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime) # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") try: - conn.execute(self._schema.trial_status.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=timestamp, - status=status.name, - )) + conn.execute( + self._schema.trial_status.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=timestamp, + status=status.name, + ) + ) except IntegrityError as ex: - _LOG.warning("Status with that timestamp already exists: %s %s :: %s", - self, timestamp, ex) + _LOG.warning( + "Status with that timestamp already exists: %s %s :: %s", + self, + timestamp, + ex, + ) diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index 5a6f8a5ee8..b5551bd856 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -29,15 +29,18 @@ class TrialSqlData(TrialData): An interface to access the trial data stored in the SQL DB. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - trial_id: int, - config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + trial_id: int, + config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status, + ): super().__init__( experiment_id=experiment_id, trial_id=trial_id, @@ -56,8 +59,11 @@ def tunable_config(self) -> TunableConfigData: Note: this corresponds to the Trial object's "tunables" property. 
""" - return TunableConfigSqlData(engine=self._engine, schema=self._schema, - tunable_config_id=self._tunable_config_id) + return TunableConfigSqlData( + engine=self._engine, + schema=self._schema, + tunable_config_id=self._tunable_config_id, + ) @property def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": @@ -68,9 +74,13 @@ def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": from mlos_bench.storage.sql.tunable_config_trial_group_data import ( TunableConfigTrialGroupSqlData, ) - return TunableConfigTrialGroupSqlData(engine=self._engine, schema=self._schema, - experiment_id=self._experiment_id, - tunable_config_id=self._tunable_config_id) + + return TunableConfigTrialGroupSqlData( + engine=self._engine, + schema=self._schema, + experiment_id=self._experiment_id, + tunable_config_id=self._tunable_config_id, + ) @property def results_df(self) -> pandas.DataFrame: @@ -79,16 +89,19 @@ def results_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_results = conn.execute( - self._schema.trial_result.select().where( + self._schema.trial_result.select() + .where( self._schema.trial_result.c.exp_id == self._experiment_id, - self._schema.trial_result.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_result.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_result.c.metric_id, ) ) return pandas.DataFrame( [(row.metric_id, row.metric_value) for row in cur_results.fetchall()], - columns=['metric', 'value']) + columns=["metric", "value"], + ) @property def telemetry_df(self) -> pandas.DataFrame: @@ -97,10 +110,12 @@ def telemetry_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select().where( + self._schema.trial_telemetry.select() + .where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_telemetry.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) @@ -108,8 +123,16 @@ def telemetry_df(self) -> pandas.DataFrame: # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. 
return pandas.DataFrame( - [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()], - columns=['ts', 'metric', 'value']) + [ + ( + utcify_timestamp(row.ts, origin="utc"), + row.metric_id, + row.metric_value, + ) + for row in cur_telemetry.fetchall() + ], + columns=["ts", "metric", "value"], + ) @property def metadata_df(self) -> pandas.DataFrame: @@ -120,13 +143,16 @@ def metadata_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_params = conn.execute( - self._schema.trial_param.select().where( + self._schema.trial_param.select() + .where( self._schema.trial_param.c.exp_id == self._experiment_id, - self._schema.trial_param.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_param.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_params.fetchall()], - columns=['parameter', 'value']) + columns=["parameter", "value"], + ) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py index e484979790..2441f70b9c 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py @@ -20,10 +20,7 @@ class TunableConfigSqlData(TunableConfigData): A configuration in this context is the set of tunable parameter values. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunable_config_id: int): + def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int): super().__init__(tunable_config_id=tunable_config_id) self._engine = engine self._schema = schema @@ -32,12 +29,13 @@ def __init__(self, *, def config_df(self) -> pandas.DataFrame: with self._engine.connect() as conn: cur_config = conn.execute( - self._schema.config_param.select().where( - self._schema.config_param.c.config_id == self._tunable_config_id - ).order_by( + self._schema.config_param.select() + .where(self._schema.config_param.c.config_id == self._tunable_config_id) + .order_by( self._schema.config_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_config.fetchall()], - columns=['parameter', 'value']) + columns=["parameter", "value"], + ) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index eb389a5940..4c3882c9a0 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -33,12 +33,15 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData): (e.g., for repeats), which we call a (tunable) config trial group. 
""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None, + ): super().__init__( experiment_id=experiment_id, tunable_config_id=tunable_config_id, @@ -53,20 +56,28 @@ def _get_tunable_config_trial_group_id(self) -> int: """ with self._engine.connect() as conn: tunable_config_trial_group = conn.execute( - self._schema.trial.select().with_only_columns( - func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - 'tunable_config_trial_group_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label( # pylint: disable=not-callable + "tunable_config_trial_group_id" + ), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.config_id == self._tunable_config_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) ) row = tunable_config_trial_group.fetchone() assert row is not None - return row._tuple()[0] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy + return row._tuple()[ + 0 + ] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy @property def tunable_config(self) -> TunableConfigData: @@ -86,8 +97,12 @@ def trials(self) -> Dict[int, "TrialData"]: trials : Dict[int, TrialData] A dictionary of the trials' data, keyed by trial id. """ - return common.get_trials(self._engine, self._schema, self._experiment_id, self._tunable_config_id) + return common.get_trials( + self._engine, self._schema, self._experiment_id, self._tunable_config_id + ) @property def results_df(self) -> pandas.DataFrame: - return common.get_results_df(self._engine, self._schema, self._experiment_id, self._tunable_config_id) + return common.get_results_df( + self._engine, self._schema, self._experiment_id, self._tunable_config_id + ) diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index 220f3d812c..22e629fc82 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -13,9 +13,9 @@ from mlos_bench.storage.base_storage import Storage -def from_config(config_file: str, - global_configs: Optional[List[str]] = None, - **kwargs: Any) -> Storage: +def from_config( + config_file: str, global_configs: Optional[List[str]] = None, **kwargs: Any +) -> Storage: """ Create a new storage object from JSON5 config file. 
@@ -36,7 +36,7 @@ def from_config(config_file: str, config_path: List[str] = kwargs.get("config_path", []) config_loader = ConfigPersistenceService({"config_path": config_path}) global_config = {} - for fname in (global_configs or []): + for fname in global_configs or []: config = config_loader.load_config(fname, ConfigSchema.GLOBALS) global_config.update(config) config_path += config.get("config_path", []) diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py index a4610da8de..64cc6c953e 100644 --- a/mlos_bench/mlos_bench/storage/util.py +++ b/mlos_bench/mlos_bench/storage/util.py @@ -25,16 +25,22 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu A dataframe with exactly two columns, 'parameter' (or 'metric') and 'value', where 'parameter' is a string and 'value' is some TunableValue or None. """ - if dataframe.columns.tolist() == ['metric', 'value']: + if dataframe.columns.tolist() == ["metric", "value"]: dataframe = dataframe.copy() - dataframe.rename(columns={'metric': 'parameter'}, inplace=True) - assert dataframe.columns.tolist() == ['parameter', 'value'] + dataframe.rename(columns={"metric": "parameter"}, inplace=True) + assert dataframe.columns.tolist() == ["parameter", "value"] data = {} - for _, row in dataframe.astype('O').iterrows(): - if not isinstance(row['value'], TunableValueTypeTuple): - raise TypeError(f"Invalid column type: {type(row['value'])} value: {row['value']}") - assert isinstance(row['parameter'], str) - if row['parameter'] in data: + for _, row in dataframe.astype("O").iterrows(): + if not isinstance(row["value"], TunableValueTypeTuple): + raise TypeError( + f"Invalid column type: {type(row['value'])} value: {row['value']}" + ) + assert isinstance(row["parameter"], str) + if row["parameter"] in data: raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe") - data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value'] + data[row["parameter"]] = ( + try_parse_val(row["value"]) + if isinstance(row["value"], str) + else row["value"] + ) return data diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index 26aa142441..a3d53a38db 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -29,26 +29,36 @@ None, ] ZONE_INFO: List[Optional[tzinfo]] = [ - nullable(pytz.timezone, zone_name) - for zone_name in ZONE_NAMES + nullable(pytz.timezone, zone_name) for zone_name in ZONE_NAMES ] # A decorator for tests that require docker. # Use with @requires_docker above a test_...() function. 
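The requires_docker and requires_ssh markers defined below are ordinary pytest.mark.skipif objects, so applying them is just a matter of decorating a test function, exactly as the comment above suggests. A hypothetical usage example (the marker and test here are made up, not part of the test suite):

import shutil

import pytest

# Same shape as the markers below: skip when an external tool is missing.
requires_git = pytest.mark.skipif(
    not shutil.which("git"), reason="git is not available on this system."
)


@requires_git
def test_git_is_on_path() -> None:
    """Runs only on machines where `git` can be found on the PATH."""
    assert shutil.which("git") is not None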
-DOCKER = shutil.which('docker') +DOCKER = shutil.which("docker") if DOCKER: - cmd = run("docker builder inspect default || docker buildx inspect default", shell=True, check=False, capture_output=True) + cmd = run( + "docker builder inspect default || docker buildx inspect default", + shell=True, + check=False, + capture_output=True, + ) stdout = cmd.stdout.decode() - if cmd.returncode != 0 or not any(line for line in stdout.splitlines() if 'Platform' in line and 'linux' in line): + if cmd.returncode != 0 or not any( + line for line in stdout.splitlines() if "Platform" in line and "linux" in line + ): debug("Docker is available but missing support for targeting linux platform.") DOCKER = None -requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.') +requires_docker = pytest.mark.skipif( + not DOCKER, reason="Docker with Linux support is not available on this system." +) # A decorator for tests that require ssh. # Use with @requires_ssh above a test_...() function. -SSH = shutil.which('ssh') -requires_ssh = pytest.mark.skipif(not SSH, reason='ssh is not available on this system.') +SSH = shutil.which("ssh") +requires_ssh = pytest.mark.skipif( + not SSH, reason="ssh is not available on this system." +) # A common seed to use to avoid tracking down race conditions and intermingling # issues of seeds across tests that run in non-deterministic parallel orders. @@ -131,10 +141,18 @@ def are_dir_trees_equal(dir1: str, dir2: str) -> bool: """ # See Also: https://stackoverflow.com/a/6681395 dirs_cmp = filecmp.dircmp(dir1, dir2) - if len(dirs_cmp.left_only) > 0 or len(dirs_cmp.right_only) > 0 or len(dirs_cmp.funny_files) > 0: - warning(f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}") + if ( + len(dirs_cmp.left_only) > 0 + or len(dirs_cmp.right_only) > 0 + or len(dirs_cmp.funny_files) > 0 + ): + warning( + f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}" + ) return False - (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) + (_, mismatch, errors) = filecmp.cmpfiles( + dir1, dir2, dirs_cmp.common_files, shallow=False + ) if len(mismatch) > 0 or len(errors) > 0: warning(f"Found differences in files:\n{mismatch}\n{errors}") return False diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index 4d728b4037..d6ee5583bb 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -18,12 +18,16 @@ from importlib.resources import files -BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace("\\", "/") +BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace( + "\\", "/" +) -def locate_config_examples(root_dir: str, - config_examples_dir: str, - examples_filter: Optional[Callable[[List[str]], List[str]]] = None) -> List[str]: +def locate_config_examples( + root_dir: str, + config_examples_dir: str, + examples_filter: Optional[Callable[[List[str]], List[str]]] = None, +) -> List[str]: """Locates all config examples in the given directory. 
Parameters diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index e1e26d7d8b..8e20001926 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -43,7 +43,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ - *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), + *locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), ] assert configs @@ -51,7 +53,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.skip(reason="Use full Launcher test (below) instead now.") @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: # pragma: no cover +def test_load_cli_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: # pragma: no cover """Tests loading a config example.""" # pylint: disable=too-complex config = config_loader_service.load_config(config_path, ConfigSchema.CLI) @@ -61,7 +65,9 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic assert isinstance(config_paths, list) config_paths.reverse() for path in config_paths: - config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access + config_loader_service._config_path.insert( + 0, path + ) # pylint: disable=protected-access # Foreach arg that references another file, see if we can at least load that too. 
args_to_skip = { @@ -78,27 +84,39 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic if arg == "globals": for path in config[arg]: - sub_config = config_loader_service.load_config(path, ConfigSchema.GLOBALS) + sub_config = config_loader_service.load_config( + path, ConfigSchema.GLOBALS + ) assert isinstance(sub_config, dict) elif arg == "environment": - sub_config = config_loader_service.load_config(config[arg], ConfigSchema.ENVIRONMENT) + sub_config = config_loader_service.load_config( + config[arg], ConfigSchema.ENVIRONMENT + ) assert isinstance(sub_config, dict) elif arg == "optimizer": - sub_config = config_loader_service.load_config(config[arg], ConfigSchema.OPTIMIZER) + sub_config = config_loader_service.load_config( + config[arg], ConfigSchema.OPTIMIZER + ) assert isinstance(sub_config, dict) elif arg == "storage": - sub_config = config_loader_service.load_config(config[arg], ConfigSchema.STORAGE) + sub_config = config_loader_service.load_config( + config[arg], ConfigSchema.STORAGE + ) assert isinstance(sub_config, dict) elif arg == "tunable_values": for path in config[arg]: - sub_config = config_loader_service.load_config(path, ConfigSchema.TUNABLE_VALUES) + sub_config = config_loader_service.load_config( + path, ConfigSchema.TUNABLE_VALUES + ) assert isinstance(sub_config, dict) else: raise NotImplementedError(f"Unhandled arg {arg} in config {config_path}") @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_cli_config_examples_via_launcher( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example via the Launcher.""" config = config_loader_service.load_config(config_path, ConfigSchema.CLI) assert isinstance(config, dict) @@ -106,29 +124,38 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # Try to load the CLI config by instantiating a launcher. # To do this we need to make sure to give it a few extra paths and globals # to look for for our examples. - cli_args = f"--config {config_path}" + \ - f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + \ - f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + \ - f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" - launcher = Launcher(description=__name__, long_text=config_path, argv=cli_args.split()) + cli_args = ( + f"--config {config_path}" + + f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + + f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + + f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" + ) + launcher = Launcher( + description=__name__, long_text=config_path, argv=cli_args.split() + ) assert launcher # Check that some parts of that config are loaded. 
- assert ConfigPersistenceService.BUILTIN_CONFIG_PATH in launcher.config_loader.config_paths + assert ( + ConfigPersistenceService.BUILTIN_CONFIG_PATH + in launcher.config_loader.config_paths + ) if config_paths := config.get("config_path"): assert isinstance(config_paths, list) for path in config_paths: # Note: Checks that the order is maintained are handled in launcher_parse_args.py - assert any(config_path.endswith(path) for config_path in launcher.config_loader.config_paths), \ - f"Expected {path} to be in {launcher.config_loader.config_paths}" + assert any( + config_path.endswith(path) + for config_path in launcher.config_loader.config_paths + ), f"Expected {path} to be in {launcher.config_loader.config_paths}" - if 'experiment_id' in config: - assert launcher.global_config['experiment_id'] == config['experiment_id'] - if 'trial_id' in config: - assert launcher.global_config['trial_id'] == config['trial_id'] + if "experiment_id" in config: + assert launcher.global_config["experiment_id"] == config["experiment_id"] + if "trial_id" in config: + assert launcher.global_config["trial_id"] == config["trial_id"] - expected_log_level = logging.getLevelName(config.get('log_level', "INFO")) + expected_log_level = logging.getLevelName(config.get("log_level", "INFO")) if isinstance(expected_log_level, int): expected_log_level = logging.getLevelName(expected_log_level) current_log_level = logging.getLevelName(logging.root.getEffectiveLevel()) @@ -136,7 +163,7 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # TODO: Check that the log_file handler is set correctly. - expected_teardown = config.get('teardown', True) + expected_teardown = config.get("teardown", True) assert launcher.teardown == expected_teardown # Note: Testing of "globals" processing handled in launcher_parse_args_test.py @@ -145,22 +172,30 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # Launcher loaded the expected types as well. assert isinstance(launcher.environment, Environment) - env_config = launcher.config_loader.load_config(config["environment"], ConfigSchema.ENVIRONMENT) + env_config = launcher.config_loader.load_config( + config["environment"], ConfigSchema.ENVIRONMENT + ) assert check_class_name(launcher.environment, env_config["class"]) assert isinstance(launcher.optimizer, Optimizer) if "optimizer" in config: - opt_config = launcher.config_loader.load_config(config["optimizer"], ConfigSchema.OPTIMIZER) + opt_config = launcher.config_loader.load_config( + config["optimizer"], ConfigSchema.OPTIMIZER + ) assert check_class_name(launcher.optimizer, opt_config["class"]) assert isinstance(launcher.storage, Storage) if "storage" in config: - storage_config = launcher.config_loader.load_config(config["storage"], ConfigSchema.STORAGE) + storage_config = launcher.config_loader.load_config( + config["storage"], ConfigSchema.STORAGE + ) assert check_class_name(launcher.storage, storage_config["class"]) assert isinstance(launcher.scheduler, Scheduler) if "scheduler" in config: - scheduler_config = launcher.config_loader.load_config(config["scheduler"], ConfigSchema.SCHEDULER) + scheduler_config = launcher.config_loader.load_config( + config["scheduler"], ConfigSchema.SCHEDULER + ) assert check_class_name(launcher.scheduler, scheduler_config["class"]) # TODO: Check that the launcher assigns the tunables values as expected. 
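The example-config tests in this file all follow the same parametrization pattern: build the list of inputs at import time, assert it is non-empty so an empty discovery fails loudly, then let pytest.mark.parametrize generate one test per entry. A minimal, self-contained version of that pattern (the data here is fabricated):

import pytest

# Normally produced by locate_config_examples(); hard-coded for illustration.
configs = ["a.jsonc", "b.jsonc", "c.json"]
assert configs  # fail at collection time if discovery found nothing


@pytest.mark.parametrize("config_path", configs)
def test_config_has_known_extension(config_path: str) -> None:
    """pytest runs this once per entry in `configs`."""
    assert config_path.endswith((".json", ".jsonc"))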
diff --git a/mlos_bench/mlos_bench/tests/config/conftest.py b/mlos_bench/mlos_bench/tests/config/conftest.py index fdcb3370cf..2c3932a128 100644 --- a/mlos_bench/mlos_bench/tests/config/conftest.py +++ b/mlos_bench/mlos_bench/tests/config/conftest.py @@ -22,9 +22,11 @@ @pytest.fixture def config_loader_service() -> ConfigPersistenceService: """Config loader service fixture.""" - return ConfigPersistenceService(config={ - "config_path": [ - str(files("mlos_bench.tests.config")), - path_join(str(files("mlos_bench.tests.config")), "globals"), - ] - }) + return ConfigPersistenceService( + config={ + "config_path": [ + str(files("mlos_bench.tests.config")), + path_join(str(files("mlos_bench.tests.config")), "globals"), + ] + } + ) diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 42925a0a5d..6ee34dbc71 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -27,16 +27,24 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" - configs_to_filter = [config_path for config_path in configs_to_filter if not config_path.endswith("-tunables.jsonc")] + configs_to_filter = [ + config_path + for config_path in configs_to_filter + if not config_path.endswith("-tunables.jsonc") + ] return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_environment_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading an environment config example.""" envs = load_environment_config_examples(config_loader_service, config_path) for env in envs: @@ -44,11 +52,15 @@ def test_load_environment_config_examples(config_loader_service: ConfigPersisten assert isinstance(env, Environment) -def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]: +def load_environment_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> List[Environment]: """Loads an environment config example.""" # Make sure that any "required_args" are provided. - global_config = config_loader_service.load_config("experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS) - global_config.setdefault('trial_id', 1) # normally populated by Launcher + global_config = config_loader_service.load_config( + "experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS + ) + global_config.setdefault("trial_id", 1) # normally populated by Launcher # Make sure we have the required services for the envs being used. 
mock_service_configs = [ @@ -60,24 +72,34 @@ def load_environment_config_examples(config_loader_service: ConfigPersistenceSer "services/remote/mock/mock_auth_service.jsonc", ] - tunable_groups = TunableGroups() # base tunable groups that all others get built on + tunable_groups = TunableGroups() # base tunable groups that all others get built on for mock_service_config_path in mock_service_configs: - mock_service_config = config_loader_service.load_config(mock_service_config_path, ConfigSchema.SERVICE) - config_loader_service.register(config_loader_service.build_service( - config=mock_service_config, parent=config_loader_service).export()) + mock_service_config = config_loader_service.load_config( + mock_service_config_path, ConfigSchema.SERVICE + ) + config_loader_service.register( + config_loader_service.build_service( + config=mock_service_config, parent=config_loader_service + ).export() + ) envs = config_loader_service.load_environment_list( - config_path, tunable_groups, global_config, service=config_loader_service) + config_path, tunable_groups, global_config, service=config_loader_service + ) return envs -composite_configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/") +composite_configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/" +) assert composite_configs @pytest.mark.parametrize("config_path", composite_configs) -def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_composite_env_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a composite env config example.""" envs = load_environment_config_examples(config_loader_service, config_path) assert len(envs) == 1 @@ -90,11 +112,15 @@ def test_load_composite_env_config_examples(config_loader_service: ConfigPersist assert child_env.tunable_params is not None checked_child_env_groups = set() - for (child_tunable, child_group) in child_env.tunable_params: + for child_tunable, child_group in child_env.tunable_params: # Lookup that tunable in the composite env. assert child_tunable in composite_env.tunable_params - (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable) - assert child_tunable is composite_tunable # Check that the tunables are the same object. + (composite_tunable, composite_group) = ( + composite_env.tunable_params.get_tunable(child_tunable) + ) + assert ( + child_tunable is composite_tunable + ) # Check that the tunables are the same object. 
if child_group.name not in checked_child_env_groups: assert child_group is composite_group checked_child_env_groups.add(child_group.name) @@ -106,10 +132,15 @@ def test_load_composite_env_config_examples(config_loader_service: ConfigPersist assert child_tunable.value == old_cat_value assert child_group[child_tunable] == old_cat_value assert composite_env.tunable_params[child_tunable] == old_cat_value - new_cat_value = [x for x in child_tunable.categories if x != old_cat_value][0] + new_cat_value = [ + x for x in child_tunable.categories if x != old_cat_value + ][0] child_tunable.category = new_cat_value assert child_env.tunable_params[child_tunable] == new_cat_value - assert composite_env.tunable_params[child_tunable] == child_tunable.category + assert ( + composite_env.tunable_params[child_tunable] + == child_tunable.category + ) elif child_tunable.is_numerical: old_num_value = child_tunable.numerical_value assert child_tunable.value == old_num_value @@ -117,4 +148,7 @@ def test_load_composite_env_config_examples(config_loader_service: ConfigPersist assert composite_env.tunable_params[child_tunable] == old_num_value child_tunable.numerical_value += 1 assert child_env.tunable_params[child_tunable] == old_num_value + 1 - assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value + assert ( + composite_env.tunable_params[child_tunable] + == child_tunable.numerical_value + ) diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index 4d8c93fdff..fd53d63788 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -29,7 +29,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ # *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), - *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs), + *locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs + ), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, "experiments", filter_configs), ] @@ -37,7 +39,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) -def test_load_globals_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_globals_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.GLOBALS) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index 6cb6253dea..c504a6d50f 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -30,12 +30,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + 
ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_optimizer_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_optimizer_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index e4264003e1..54d619caf1 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -34,14 +34,19 @@ def __hash__(self) -> int: # The different type of schema test cases we expect to have. -_SCHEMA_TEST_TYPES = {x.test_case_type: x for x in ( - SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}), - SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}), -)} +_SCHEMA_TEST_TYPES = { + x.test_case_type: x + for x in ( + SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}), + SchemaTestType( + test_case_type="bad", test_case_subtypes={"invalid", "unhandled"} + ), + ) +} @dataclass -class SchemaTestCaseInfo(): +class SchemaTestCaseInfo: """ Some basic info about a schema test case. """ @@ -61,15 +66,22 @@ def check_schema_dir_layout(test_cases_root: str) -> None: any extra configs or test cases. """ for test_case_dir in os.listdir(test_cases_root): - if test_case_dir == 'README.md': + if test_case_dir == "README.md": continue if test_case_dir not in _SCHEMA_TEST_TYPES: raise NotImplementedError(f"Unhandled test case type: {test_case_dir}") - for test_case_subdir in os.listdir(os.path.join(test_cases_root, test_case_dir)): - if test_case_subdir == 'README.md': + for test_case_subdir in os.listdir( + os.path.join(test_cases_root, test_case_dir) + ): + if test_case_subdir == "README.md": continue - if test_case_subdir not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes: - raise NotImplementedError(f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}") + if ( + test_case_subdir + not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes + ): + raise NotImplementedError( + f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}" + ) @dataclass @@ -87,15 +99,23 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: """ Gets a dict of schema test cases from the given root. """ - test_cases = TestCases(by_path={}, - by_type={x: {} for x in _SCHEMA_TEST_TYPES}, - by_subtype={y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes}) + test_cases = TestCases( + by_path={}, + by_type={x: {} for x in _SCHEMA_TEST_TYPES}, + by_subtype={ + y: {} + for x in _SCHEMA_TEST_TYPES + for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes + }, + ) check_schema_dir_layout(test_cases_root) # Note: we sort the test cases so that we can deterministically test them in parallel. 
- for (test_case_type, schema_test_type) in _SCHEMA_TEST_TYPES.items(): + for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items(): for test_case_subtype in schema_test_type.test_case_subtypes: - for test_case_file in locate_config_examples(test_cases_root, os.path.join(test_case_type, test_case_subtype)): - with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh: + for test_case_file in locate_config_examples( + test_cases_root, os.path.join(test_case_type, test_case_subtype) + ): + with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh: try: test_case_info = SchemaTestCaseInfo( config=json5.load(test_case_fh), @@ -103,11 +123,19 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: test_case_type=test_case_type, test_case_subtype=test_case_subtype, ) - test_cases.by_path[test_case_info.test_case_file] = test_case_info - test_cases.by_type[test_case_info.test_case_type][test_case_info.test_case_file] = test_case_info - test_cases.by_subtype[test_case_info.test_case_subtype][test_case_info.test_case_file] = test_case_info + test_cases.by_path[test_case_info.test_case_file] = ( + test_case_info + ) + test_cases.by_type[test_case_info.test_case_type][ + test_case_info.test_case_file + ] = test_case_info + test_cases.by_subtype[test_case_info.test_case_subtype][ + test_case_info.test_case_file + ] = test_case_info except Exception as ex: - raise RuntimeError("Failed to load test case: " + test_case_file) from ex + raise RuntimeError( + "Failed to load test case: " + test_case_file + ) from ex assert test_cases assert len(test_cases.by_type["good"]) > 0 @@ -117,7 +145,9 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: return test_cases -def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: +def check_test_case_against_schema( + test_case: SchemaTestCaseInfo, schema_type: ConfigSchema +) -> None: """ Checks the given test case against the given schema. @@ -142,7 +172,9 @@ def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: C raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}") -def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: +def check_test_case_config_with_extra_param( + test_case: SchemaTestCaseInfo, schema_type: ConfigSchema +) -> None: """ Checks that the config fails to validate if extra params are present in certain places. """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 5dd1666008..a3401baf7f 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -26,6 +26,7 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_cli_configs_against_schema(test_case_name: str) -> None: """ @@ -36,7 +37,9 @@ def test_cli_configs_against_schema(test_case_name: str) -> None: # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. 
- check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -44,9 +47,13 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the cli config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI + ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index dc3cd40425..efb9e8019d 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -33,23 +33,29 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_ENV_CLASSES = { - ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. + ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. } -expected_environment_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass - in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench') - if subclass not in NON_CONFIG_ENV_CLASSES] +expected_environment_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Environment, pkg_name="mlos_bench") + if subclass not in NON_CONFIG_ENV_CLASSES +] assert expected_environment_class_names COMPOSITE_ENV_CLASS_NAME = CompositeEnv.__module__ + "." + CompositeEnv.__name__ -expected_leaf_environment_class_names = [subclass_name for subclass_name in expected_environment_class_names - if subclass_name != COMPOSITE_ENV_CLASS_NAME] +expected_leaf_environment_class_names = [ + subclass_name + for subclass_name in expected_environment_class_names + if subclass_name != COMPOSITE_ENV_CLASS_NAME +] # Do the full cross product of all the test cases and all the Environment types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("env_class", expected_environment_class_names) -def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_class: str) -> None: +def test_case_coverage_mlos_bench_environment_type( + test_case_subtype: str, env_class: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench Environment type. 
""" @@ -57,18 +63,24 @@ def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_c if try_resolve_class_name(test_case.config.get("class")) == env_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}") + f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_environment_configs_against_schema(test_case_name: str) -> None: """ Checks that the environment config validates against the schema. """ - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.ENVIRONMENT) - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.ENVIRONMENT + ) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -76,5 +88,9 @@ def test_environment_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the environment config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 5045bf510b..7cf497695b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -25,14 +25,19 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_globals_configs_against_schema(test_case_name: str) -> None: """ Checks that the CLI config validates against the schema. """ - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS + ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. 
- check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index e9ee653644..6a9d43864f 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -33,9 +33,12 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_optimizer_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Optimizer, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_optimizer_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Optimizer, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_optimizer_class_names # Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None). @@ -49,16 +52,25 @@ # Do the full cross product of all the test cases and all the optimizer types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names) -def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_bench_optimizer_type: str) -> None: +@pytest.mark.parametrize( + "mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names +) +def test_case_coverage_mlos_bench_optimizer_type( + test_case_subtype: str, mlos_bench_optimizer_type: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench optimizer type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): - if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type: + if ( + try_resolve_class_name(test_case.config.get("class")) + == mlos_bench_optimizer_type + ): return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}") + f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}" + ) + # Being a little lazy for the moment and relaxing the requirement that we have # a subtype test case for each optimizer and space adapter combo. @@ -67,54 +79,77 @@ def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_be @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types) -def test_case_coverage_mlos_core_optimizer_type(test_case_type: str, - mlos_core_optimizer_type: Optional[OptimizerType]) -> None: +def test_case_coverage_mlos_core_optimizer_type( + test_case_type: str, mlos_core_optimizer_type: Optional[OptimizerType] +) -> None: """ Checks to see if there is a given type of test case for the given mlos_core optimizer type. 
""" - optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name + optimizer_name = ( + None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name + ) for test_case in TEST_CASES.by_type[test_case_type].values(): - if try_resolve_class_name(test_case.config.get("class")) \ - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": + if ( + try_resolve_class_name(test_case.config.get("class")) + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" + ): optimizer_type = None if test_case.config.get("config"): optimizer_type = test_case.config["config"].get("optimizer_type", None) if optimizer_type == optimizer_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}") + f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}" + ) @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types) -def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str, - mlos_core_space_adapter_type: Optional[SpaceAdapterType]) -> None: +@pytest.mark.parametrize( + "mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types +) +def test_case_coverage_mlos_core_space_adapter_type( + test_case_type: str, mlos_core_space_adapter_type: Optional[SpaceAdapterType] +) -> None: """ Checks to see if there is a given type of test case for the given mlos_core space adapter type. """ - space_adapter_name = None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name + space_adapter_name = ( + None + if mlos_core_space_adapter_type is None + else mlos_core_space_adapter_type.name + ) for test_case in TEST_CASES.by_type[test_case_type].values(): - if try_resolve_class_name(test_case.config.get("class")) \ - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": + if ( + try_resolve_class_name(test_case.config.get("class")) + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" + ): space_adapter_type = None if test_case.config.get("config"): - space_adapter_type = test_case.config["config"].get("space_adapter_type", None) + space_adapter_type = test_case.config["config"].get( + "space_adapter_type", None + ) if space_adapter_type == space_adapter_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}") + f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_optimizer_configs_against_schema(test_case_name: str) -> None: """ Checks that the optimizer config validates against the schema. 
""" - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.OPTIMIZER) - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.OPTIMIZER + ) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -122,5 +157,9 @@ def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the optimizer config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 8fccba8bc7..279c171a90 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -30,25 +30,37 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_scheduler_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Scheduler, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_scheduler_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Scheduler, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_scheduler_class_names # Do the full cross product of all the test cases and all the scheduler types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names) -def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_bench_scheduler_type: str) -> None: +@pytest.mark.parametrize( + "mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names +) +def test_case_coverage_mlos_bench_scheduler_type( + test_case_subtype: str, mlos_bench_scheduler_type: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench scheduler type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): - if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_scheduler_type: + if ( + try_resolve_class_name(test_case.config.get("class")) + == mlos_bench_scheduler_type + ): return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}") + f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}" + ) + # Now we actually perform all of those validation tests. @@ -58,8 +70,12 @@ def test_scheduler_configs_against_schema(test_case_name: str) -> None: """ Checks that the scheduler config validates against the schema. 
""" - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SCHEDULER) - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.SCHEDULER + ) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -67,8 +83,12 @@ def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the scheduler config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 64c6fccccd..f7daf3f422 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -38,30 +38,33 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_SERVICE_CLASSES = { - ConfigPersistenceService, # configured thru the launcher cli args - TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. - AzureDeploymentService, # ABCMeta abstract base class - SshService, # ABCMeta abstract base class + ConfigPersistenceService, # configured thru the launcher cli args + TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. + AzureDeploymentService, # ABCMeta abstract base class + SshService, # ABCMeta abstract base class } -expected_service_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass - in get_all_concrete_subclasses(Service, pkg_name='mlos_bench') - if subclass not in NON_CONFIG_SERVICE_CLASSES] +expected_service_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Service, pkg_name="mlos_bench") + if subclass not in NON_CONFIG_SERVICE_CLASSES +] assert expected_service_class_names # Do the full cross product of all the test cases and all the Service types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("service_class", expected_service_class_names) -def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_class: str) -> None: +def test_case_coverage_mlos_bench_service_type( + test_case_subtype: str, service_class: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench Service type. 
""" for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): config_list: List[Dict[str, Any]] if not isinstance(test_case.config, dict): - continue # type: ignore[unreachable] + continue # type: ignore[unreachable] if "class" not in test_case.config: config_list = test_case.config["services"] else: @@ -70,18 +73,24 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c if try_resolve_class_name(config.get("class")) == service_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for service class {service_class}") + f"Missing test case for subtype {test_case_subtype} for service class {service_class}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_service_configs_against_schema(test_case_name: str) -> None: """ Checks that the service config validates against the schema. """ - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SERVICE) - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.SERVICE + ) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -89,5 +98,9 @@ def test_service_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the service config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index 9b362b5e0d..640ae450f3 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -28,36 +28,52 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_storage_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Storage, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_storage_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Storage, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_storage_class_names # Do the full cross product of all the test cases and all the storage types. 
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names) -def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_bench_storage_type: str) -> None: +@pytest.mark.parametrize( + "mlos_bench_storage_type", expected_mlos_bench_storage_class_names +) +def test_case_coverage_mlos_bench_storage_type( + test_case_subtype: str, mlos_bench_storage_type: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench storage type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): - if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_storage_type: + if ( + try_resolve_class_name(test_case.config.get("class")) + == mlos_bench_storage_type + ): return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}") + f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_storage_configs_against_schema(test_case_name: str) -> None: """ Checks that the storage config validates against the schema. """ - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.STORAGE) - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.STORAGE + ) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -65,9 +81,15 @@ def test_storage_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the storage config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) - - -if __name__ == '__main__': - pytest.main([__file__, '-n0'],) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) + + +if __name__ == "__main__": + pytest.main( + [__file__, "-n0"], + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py index a6d0de9313..cf0223d006 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py @@ -25,10 +25,15 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_params_configs_against_schema(test_case_name: str) -> None: """ Checks that the tunable params config validates against the schema. 
""" - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_PARAMS) - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_PARAMS + ) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index d871eaa212..04d2f4c709 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -25,14 +25,19 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_values_configs_against_schema(test_case_name: str) -> None: """ Checks that the tunable values config validates against the schema. """ - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES + ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. - check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) + check_test_case_against_schema( + TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 32034eb11c..8431251098 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -25,19 +25,27 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" + def predicate(config_path: str) -> bool: - arm_template = config_path.find("services/remote/azure/arm-templates/") >= 0 and config_path.endswith(".jsonc") + arm_template = config_path.find( + "services/remote/azure/arm-templates/" + ) >= 0 and config_path.endswith(".jsonc") setup_rg_scripts = config_path.find("azure/scripts/setup-rg") >= 0 return not (arm_template or setup_rg_scripts) + return [config_path for config_path in configs_to_filter if predicate(config_path)] -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_service_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_service_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) # Make an instance of the class based on the config. 
diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index 2f9773a9b0..d1d39ec4f5 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -29,12 +29,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_storage_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_storage_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.STORAGE) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index 58359eb983..28c83f453c 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -42,7 +42,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score"], }, - tunables=tunable_groups + tunables=tunable_groups, ) @@ -59,7 +59,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score", "other_score"], }, - tunables=tunable_groups + tunables=tunable_groups, ) @@ -82,7 +82,9 @@ def docker_compose_file(pytestconfig: pytest.Config) -> List[str]: """ _ = pytestconfig # unused return [ - os.path.join(os.path.dirname(__file__), "services", "remote", "ssh", "docker-compose.yml"), + os.path.join( + os.path.dirname(__file__), "services", "remote", "ssh", "docker-compose.yml" + ), # Add additional configs as necessary here. ] @@ -103,7 +105,9 @@ def docker_compose_project_name(short_testrun_uid: str) -> str: @pytest.fixture(scope="session") -def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessReaderWriterLock: +def docker_services_lock( + shared_temp_dir: str, short_testrun_uid: str +) -> InterProcessReaderWriterLock: """ Gets a pytest session lock for xdist workers to mark when they're using the docker services. @@ -113,11 +117,15 @@ def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterP A lock to ensure that setup/teardown operations don't happen while a worker is using the docker services. """ - return InterProcessReaderWriterLock(f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock") + return InterProcessReaderWriterLock( + f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock" + ) @pytest.fixture(scope="session") -def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessLock: +def docker_setup_teardown_lock( + shared_temp_dir: str, short_testrun_uid: str +) -> InterProcessLock: """ Gets a pytest session lock between xdist workers for the docker setup/teardown operations. @@ -126,7 +134,9 @@ def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> ------ A lock to ensure that only one worker is doing setup/teardown at a time. 
""" - return InterProcessLock(f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock") + return InterProcessLock( + f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock" + ) @pytest.fixture(scope="session") diff --git a/mlos_bench/mlos_bench/tests/dict_templater_test.py b/mlos_bench/mlos_bench/tests/dict_templater_test.py index 63219d9246..6604656c9a 100644 --- a/mlos_bench/mlos_bench/tests/dict_templater_test.py +++ b/mlos_bench/mlos_bench/tests/dict_templater_test.py @@ -124,7 +124,9 @@ def test_from_extras_expansion(source_template_dict: Dict[str, Any]) -> None: "extra_str": "str-from-extras", "string": "shouldn't be used", } - results = DictTemplater(source_template_dict).expand_vars(extra_source_dict=extra_source_dict) + results = DictTemplater(source_template_dict).expand_vars( + extra_source_dict=extra_source_dict + ) assert results == { "extra_str-ref": f"{extra_source_dict['extra_str']}-ref", "str": "string", diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py index ac0b942167..8218577986 100644 --- a/mlos_bench/mlos_bench/tests/environments/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/__init__.py @@ -16,11 +16,13 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def check_env_success(env: Environment, - tunable_groups: TunableGroups, - expected_results: Dict[str, TunableValue], - expected_telemetry: List[Tuple[datetime, str, Any]], - global_config: Optional[dict] = None) -> None: +def check_env_success( + env: Environment, + tunable_groups: TunableGroups, + expected_results: Dict[str, TunableValue], + expected_telemetry: List[Tuple[datetime, str, Any]], + global_config: Optional[dict] = None, +) -> None: """ Set up an environment and run a test experiment there. @@ -50,7 +52,7 @@ def check_env_success(env: Environment, assert telemetry == pytest.approx(expected_telemetry, nan_ok=True) env_context.teardown() - assert not env_context._is_ready # pylint: disable=protected-access + assert not env_context._is_ready # pylint: disable=protected-access def check_env_fail_telemetry(env: Environment, tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index 8afb8e5cda..863e5aaa80 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -29,8 +29,8 @@ def test_expand_groups() -> None: Check the dollar variable expansion for tunable groups. 
""" assert Environment._expand_groups( - ["begin", "$list", "$empty", "$str", "end"], - _GROUPS) == ["begin", "c", "d", "efg", "end"] + ["begin", "$list", "$empty", "$str", "end"], _GROUPS + ) == ["begin", "c", "d", "efg", "end"] def test_expand_groups_empty_input() -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index 6497eb6985..f8f6d28afe 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -34,26 +34,32 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: { "name": "Env 2 :: tmp_other_2", "class": "mlos_bench.environments.mock_env.MockEnv", - "include_services": ["services/local/mock/mock_local_exec_service_2.jsonc"], + "include_services": [ + "services/local/mock/mock_local_exec_service_2.jsonc" + ], }, { "name": "Env 3 :: tmp_other_3", "class": "mlos_bench.environments.mock_env.MockEnv", - "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"], - } + "include_services": [ + "services/local/mock/mock_local_exec_service_3.jsonc" + ], + }, ] }, tunables=tunable_groups, service=LocalExecService( - config={ - "temp_dir": "_test_tmp_global" - }, - parent=ConfigPersistenceService({ - "config_path": [ - path_join(os.path.dirname(__file__), "../config", abs_path=True), - ] - }) - ) + config={"temp_dir": "_test_tmp_global"}, + parent=ConfigPersistenceService( + { + "config_path": [ + path_join( + os.path.dirname(__file__), "../config", abs_path=True + ), + ] + } + ), + ), ) @@ -61,7 +67,11 @@ def test_composite_services(composite_env: CompositeEnv) -> None: """ Check that each environment gets its own instance of the services. 
""" - for (i, path) in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): + for i, path in ( + (0, "_test_tmp_global"), + (1, "_test_tmp_other_2"), + (2, "_test_tmp_other_3"), + ): service = composite_env.children[i]._service # pylint: disable=protected-access assert service is not None and hasattr(service, "temp_dir_context") with service.temp_dir_context() as temp_dir: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index 742eaf3c79..184aad778d 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -28,7 +28,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", "someConst": "root", - "global_param": "default" + "global_param": "default", }, "children": [ { @@ -43,7 +43,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "someConst", "global_param"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, { "name": "Mock Server Environment 2", @@ -53,12 +53,12 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vmName": "$vm_server_name", "EnvId": 2, - "global_param": "local" + "global_param": "local", }, "required_args": ["vmName"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, { "name": "Mock Control Environment 3", @@ -72,15 +72,13 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "vm_server_name", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } - } - ] + }, + }, + ], }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={ - "global_param": "global_value" - } + global_config={"global_param": "global_value"}, ) @@ -90,61 +88,65 @@ def test_composite_env_params(composite_env: CompositeEnv) -> None: NOTE: The current logic is that variables flow down via required_args and const_args, parent """ assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value" # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args 
from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", - "vm_server_name": "Mock Server VM" + "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config } -def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: +def test_composite_env_setup( + composite_env: CompositeEnv, tunable_groups: TunableGroups +) -> None: """ Check that the child environments update their tunable parameters. """ - tunable_groups.assign({ - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + } + ) with composite_env as env_context: assert env_context.setup(tunable_groups) assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value" # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "mwait", # tunable_params from the parent "vm_client_name": "Mock Client VM", "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config @@ -163,7 +165,7 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", - "someConst": "root" + "someConst": "root", }, "children": [ { @@ -175,7 +177,12 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "vmName": "$vm_client_name", "EnvId": 1, }, - "required_args": ["vmName", "EnvId", "someConst", "vm_server_name"], + "required_args": [ + "vmName", + "EnvId", + "someConst", + "vm_server_name", + ], "children": [ { "name": "Mock Client Environment 1", @@ -191,11 +198,11 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "EnvId", "someConst", "vm_server_name", - "global_param" + 
"global_param", ], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, # ... ], @@ -217,23 +224,24 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "class": "mlos_bench.environments.mock_env.MockEnv", "config": { "tunable_params": ["boot"], - "required_args": ["vmName", "EnvId", "vm_client_name"], + "required_args": [ + "vmName", + "EnvId", + "vm_client_name", + ], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, # ... ], }, }, - - ] + ], }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={ - "global_param": "global_value" - } + global_config={"global_param": "global_value"}, ) @@ -244,52 +252,56 @@ def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None """ assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", # "global_param": "global_value" # not required, so not picked from the global_config } -def test_nested_composite_env_setup(nested_composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: +def test_nested_composite_env_setup( + nested_composite_env: CompositeEnv, tunable_groups: TunableGroups +) -> None: """ Check that the child environments update their tunable parameters. 
""" - tunable_groups.assign({ - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + } + ) with nested_composite_env as env_context: assert env_context.setup(tunable_groups) assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", } diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py index 7395aa3e15..bf3407b506 100644 --- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py +++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py @@ -18,7 +18,7 @@ def test_one_group(tunable_groups: TunableGroups) -> None: env = MockEnv( name="Test Env", config={"tunable_params": ["provision"]}, - tunables=tunable_groups + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -32,7 +32,7 @@ def test_two_groups(tunable_groups: TunableGroups) -> None: env = MockEnv( name="Test Env", config={"tunable_params": ["provision", "kernel"]}, - tunables=tunable_groups + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -55,7 +55,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups + tunables=tunable_groups, ) expected_params = { "vmSize": "Standard_B4ms", @@ -80,11 +80,7 @@ def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None: """ Make sure that no tunable groups are available to the environment by default. """ - env = MockEnv( - name="Test Env", - config={}, - tunables=tunable_groups - ) + env = MockEnv(name="Test Env", config={}, tunables=tunable_groups) assert env.tunable_params.get_param_values() == {} @@ -94,9 +90,7 @@ def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None: when explicitly specifying an empty list of tunable_params. 
""" env = MockEnv( - name="Test Env", - config={"tunable_params": []}, - tunables=tunable_groups + name="Test Env", config={"tunable_params": []}, tunables=tunable_groups ) assert env.tunable_params.get_param_values() == {} @@ -114,7 +108,7 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == {} @@ -137,9 +131,7 @@ def test_loader_level_include() -> None: env_json = { "class": "mlos_bench.environments.mock_env.MockEnv", "name": "Test Env", - "include_tunables": [ - "environments/os/linux/boot/linux-boot-tunables.jsonc" - ], + "include_tunables": ["environments/os/linux/boot/linux-boot-tunables.jsonc"], "config": { "tunable_params": ["linux-kernel-boot"], "const_args": { @@ -148,12 +140,14 @@ def test_loader_level_include() -> None: }, }, } - loader = ConfigPersistenceService({ - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - }) + loader = ConfigPersistenceService( + { + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + } + ) env = loader.build_environment(config=env_json, tunables=TunableGroups()) expected_params = { "align_va_addr": "on", diff --git a/mlos_bench/mlos_bench/tests/environments/local/__init__.py b/mlos_bench/mlos_bench/tests/environments/local/__init__.py index 5d8fc32c6b..c68d2fa7b8 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/local/__init__.py @@ -32,14 +32,20 @@ def create_local_env(tunable_groups: TunableGroups, config: Dict[str, Any]) -> L env : LocalEnv A new instance of the local environment. """ - return LocalEnv(name="TestLocalEnv", config=config, tunables=tunable_groups, - service=LocalExecService(parent=ConfigPersistenceService())) + return LocalEnv( + name="TestLocalEnv", + config=config, + tunables=tunable_groups, + service=LocalExecService(parent=ConfigPersistenceService()), + ) -def create_composite_local_env(tunable_groups: TunableGroups, - global_config: Dict[str, Any], - params: Dict[str, Any], - local_configs: List[Dict[str, Any]]) -> CompositeEnv: +def create_composite_local_env( + tunable_groups: TunableGroups, + global_config: Dict[str, Any], + params: Dict[str, Any], + local_configs: List[Dict[str, Any]], +) -> CompositeEnv: """ Create a CompositeEnv with several LocalEnv instances. @@ -70,7 +76,7 @@ def create_composite_local_env(tunable_groups: TunableGroups, "config": config, } for (i, config) in enumerate(local_configs) - ] + ], }, tunables=tunable_groups, global_config=global_config, diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index 9bcb7aa218..c38c6bc584 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -26,7 +26,9 @@ def _format_str(zone_info: Optional[tzinfo]) -> str: # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_composite_env( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Produce benchmark and telemetry data in TWO local environments and combine the results. 
@@ -43,7 +45,7 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - (var_prefix, var_suffix) = ("%", "%") if sys.platform == 'win32' else ("$", "") + (var_prefix, var_suffix) = ("%", "%") if sys.platform == "win32" else ("$", "") env = create_composite_local_env( tunable_groups=tunable_groups, @@ -67,8 +69,8 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo "required_args": ["errors", "reads"], "shell_env_params": [ "latency", # const_args overridden by the composite env - "errors", # Comes from the parent const_args - "reads" # const_args overridden by the global config + "errors", # Comes from the parent const_args + "reads", # const_args overridden by the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -90,9 +92,9 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo }, "required_args": ["writes"], "shell_env_params": [ - "throughput", # const_args overridden by the composite env - "score", # Comes from the local const_args - "writes" # Comes straight from the global config + "throughput", # const_args overridden by the composite env + "score", # Comes from the local const_args + "writes", # Comes straight from the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -106,12 +108,13 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo ], "read_results_file": "output.csv", "read_telemetry_file": "telemetry.csv", - } - ] + }, + ], ) check_env_success( - env, tunable_groups, + env, + tunable_groups, expected_results={ "latency": 4.2, "throughput": 768.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py index 20854b9f9e..bdcd9f885f 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py @@ -17,19 +17,23 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. """ - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", - ], - "results_stdout_pattern": r"(\w+),([0-9.]+)", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", + ], + "results_stdout_pattern": r"(\w+),([0-9.]+)", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -45,19 +49,23 @@ def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. 
""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern - ], - "results_stdout_pattern": r"^(\w+),([0-9.]+)$", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern + ], + "results_stdout_pattern": r"^(\w+),([0-9.]+)$", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -72,24 +80,28 @@ def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to *BOTH* stdout and a file and extract the results from both. """ - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'stdout-msg,string'", - "echo '-------------------'", # Should be ignored - "echo 'metric,value' > output.csv", - "echo 'extra1,333' >> output.csv", - "echo 'extra2,444' >> output.csv", - "echo 'file-msg,string' >> output.csv", - ], - "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'stdout-msg,string'", + "echo '-------------------'", # Should be ignored + "echo 'metric,value' > output.csv", + "echo 'extra1,333' >> output.csv", + "echo 'extra2,444' >> output.csv", + "echo 'file-msg,string' >> output.csv", + ], + "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index 35bdb39486..f620165de8 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -25,7 +25,9 @@ def _format_str(zone_info: Optional[tzinfo]) -> str: # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_local_env_telemetry( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Produce benchmark and telemetry data in a local script and read it. 
""" @@ -37,25 +39,29 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,4.1' >> output.csv", - "echo 'throughput,512' >> output.csv", - "echo 'score,0.95' >> output.csv", - "echo '-------------------'", # This output does not go anywhere - "echo 'timestamp,metric,value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_results_file": "output.csv", - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,4.1' >> output.csv", + "echo 'throughput,512' >> output.csv", + "echo 'score,0.95' >> output.csv", + "echo '-------------------'", # This output does not go anywhere + "echo 'timestamp,metric,value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_results_file": "output.csv", + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 4.1, "throughput": 512.0, @@ -72,7 +78,9 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_local_env_telemetry_no_header( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Read the telemetry data with no header. 
""" @@ -84,18 +92,22 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - f"echo {time_str1},cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + f"echo {time_str1},cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={}, expected_telemetry=[ (ts1.astimezone(UTC), "cpu_load", 0.65), @@ -106,9 +118,13 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: ) -@pytest.mark.filterwarnings("ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0") # pylint: disable=line-too-long # noqa +@pytest.mark.filterwarnings( + "ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0" +) # pylint: disable=line-too-long # noqa @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_local_env_telemetry_wrong_header( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Read the telemetry data with incorrect header. 
""" @@ -120,17 +136,20 @@ def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_in time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: the data is correct, but the header has unexpected column names - "echo 'ts,metric_name,metric_value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: the data is correct, but the header has unexpected column names + "echo 'ts,metric_name,metric_value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) @@ -148,16 +167,19 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: too many columns - f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: too many columns + f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) @@ -166,15 +188,18 @@ def test_local_env_telemetry_invalid_ts(tunable_groups: TunableGroups) -> None: """ Fail when the telemetry data has wrong format. """ - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: field 1 must be a timestamp - "echo 1,cpu_load,0.65 > telemetry.csv", - "echo 2,mem_usage,10240 >> telemetry.csv", - "echo 3,cpu_load,0.8 >> telemetry.csv", - "echo 4,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: field 1 must be a timestamp + "echo 1,cpu_load,0.65 > telemetry.csv", + "echo 2,mem_usage,10240 >> telemetry.csv", + "echo 3,cpu_load,0.8 >> telemetry.csv", + "echo 4,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 6cb4fd4f7e..2b51ae1f0e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -16,18 +16,22 @@ def test_local_env(tunable_groups: TunableGroups) -> None: """ Produce benchmark and telemetry data in a local script and read it. 
""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,10' >> output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,10' >> output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 10.0, "throughput": 66.0, @@ -41,9 +45,7 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: """ Basic check that context support for Service mixins are handled when environment contexts are entered. """ - local_env = create_local_env(tunable_groups, { - "run": ["echo NA"] - }) + local_env = create_local_env(tunable_groups, {"run": ["echo NA"]}) # pylint: disable=protected-access assert local_env._service assert not local_env._service._in_context @@ -51,10 +53,10 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: with local_env as env_context: assert env_context._in_context assert local_env._service._in_context - assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) + assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) assert all(svc._in_context for svc in local_env._service._service_contexts) assert all(svc._in_context for svc in local_env._service._services) - assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) + assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) assert not local_env._service._service_contexts assert not any(svc._in_context for svc in local_env._service._services) @@ -63,15 +65,18 @@ def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: """ Fail if the results are not in the expected format. """ - local_env = create_local_env(tunable_groups, { - "run": [ - # No header - "echo 'latency,10' > output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # No header + "echo 'latency,10' > output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) with local_env as env_context: assert env_context.setup(tunable_groups) @@ -83,16 +88,20 @@ def test_local_env_wide(tunable_groups: TunableGroups) -> None: """ Produce benchmark data in wide format and read it. 
""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'latency,throughput,score' > output.csv", - "echo '10,66,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'latency,throughput,score' > output.csv", + "echo '10,66,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 10, "throughput": 66, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py index c16eac4459..52e15be076 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py @@ -14,31 +14,36 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: dict) -> None: +def _run_local_env( + tunable_groups: TunableGroups, shell_subcmd: str, expected: dict +) -> None: """ Check that LocalEnv can set shell environment variables. """ - local_env = create_local_env(tunable_groups, { - "const_args": { - "const_arg": 111, # Passed into "shell_env_params" - "other_arg": 222, # NOT passed into "shell_env_params" + local_env = create_local_env( + tunable_groups, + { + "const_args": { + "const_arg": 111, # Passed into "shell_env_params" + "other_arg": 222, # NOT passed into "shell_env_params" + }, + "tunable_params": ["kernel"], + "shell_env_params": [ + "const_arg", # From "const_arg" + "kernel_sched_latency_ns", # From "tunable_params" + ], + "run": [ + "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", + f"echo {shell_subcmd} >> output.csv", + ], + "read_results_file": "output.csv", }, - "tunable_params": ["kernel"], - "shell_env_params": [ - "const_arg", # From "const_arg" - "kernel_sched_latency_ns", # From "tunable_params" - ], - "run": [ - "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", - f"echo {shell_subcmd} >> output.csv", - ], - "read_results_file": "output.csv", - }) + ) check_env_success(local_env, tunable_groups, expected, []) -@pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only") +@pytest.mark.skipif(sys.platform == "win32", reason="sh-like shell only") def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables in sh-like shell. @@ -47,15 +52,15 @@ def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd="$const_arg,$other_arg,$unknown_arg,$kernel_sched_latency_ns", expected={ - "const_arg": 111, # From "const_args" - "other_arg": float("NaN"), # Not included in "shell_env_params" - "unknown_arg": float("NaN"), # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - } + "const_arg": 111, # From "const_args" + "other_arg": float("NaN"), # Not included in "shell_env_params" + "unknown_arg": float("NaN"), # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + }, ) -@pytest.mark.skipif(sys.platform != 'win32', reason="Windows only") +@pytest.mark.skipif(sys.platform != "win32", reason="Windows only") def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables on Windows / cmd shell. 
@@ -64,9 +69,9 @@ def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd=r"%const_arg%,%other_arg%,%unknown_arg%,%kernel_sched_latency_ns%", expected={ - "const_arg": 111, # From "const_args" - "other_arg": r"%other_arg%", # Not included in "shell_env_params" - "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - } + "const_arg": 111, # From "const_args" + "other_arg": r"%other_arg%", # Not included in "shell_env_params" + "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + }, ) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index 8bce053f7b..25e75cf748 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -25,13 +25,14 @@ def mock_fileshare_service() -> MockFileShareService: """ return MockFileShareService( config={"fileShareName": "MOCK_FILESHARE"}, - parent=LocalExecService(parent=ConfigPersistenceService()) + parent=LocalExecService(parent=ConfigPersistenceService()), ) @pytest.fixture -def local_fileshare_env(tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService) -> LocalFileShareEnv: +def local_fileshare_env( + tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService +) -> LocalFileShareEnv: """ Create a LocalFileShareEnv instance. """ @@ -40,12 +41,12 @@ def local_fileshare_env(tunable_groups: TunableGroups, config={ "const_args": { "experiment_id": "EXP_ID", # Passed into "shell_env_params" - "trial_id": 222, # NOT passed into "shell_env_params" + "trial_id": 222, # NOT passed into "shell_env_params" }, "tunable_params": ["boot"], "shell_env_params": [ - "trial_id", # From "const_arg" - "idle", # From "tunable_params", == "halt" + "trial_id", # From "const_arg" + "idle", # From "tunable_params", == "halt" ], "upload": [ { @@ -57,9 +58,7 @@ def local_fileshare_env(tunable_groups: TunableGroups, "to": "$experiment_id/$trial_id/input/data_$idle.csv", }, ], - "run": [ - "echo No-op run" - ], + "run": ["echo No-op run"], "download": [ { "from": "$experiment_id/$trial_id/$idle/data.csv", @@ -73,9 +72,11 @@ def local_fileshare_env(tunable_groups: TunableGroups, return env -def test_local_fileshare_env(tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService, - local_fileshare_env: LocalFileShareEnv) -> None: +def test_local_fileshare_env( + tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService, + local_fileshare_env: LocalFileShareEnv, +) -> None: """ Test that the LocalFileShareEnv correctly expands the `$VAR` variables in the upload and download sections of the config. 
diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py index 608edbf9ef..427fe90706 100644 --- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py @@ -28,7 +28,9 @@ def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> N assert data["score"] == pytest.approx(72.92, 0.01) -def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups) -> None: +def test_mock_env_no_noise( + mock_env_no_noise: MockEnv, tunable_groups: TunableGroups +) -> None: """ Check the default values of the mock environment. """ @@ -42,20 +44,33 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr assert data["score"] == pytest.approx(75.0, 0.01) -@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000 - }, 66.4), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000 - }, 74.06), -]) -def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float) -> None: +@pytest.mark.parametrize( + ("tunable_values", "expected_score"), + [ + ( + { + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 250000, + }, + 66.4, + ), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000, + }, + 74.06, + ), + ], +) +def test_mock_env_assign( + mock_env: MockEnv, + tunable_groups: TunableGroups, + tunable_values: dict, + expected_score: float, +) -> None: """ Check the benchmark values of the mock environment after the assignment. """ @@ -68,21 +83,33 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, assert data["score"] == pytest.approx(expected_score, 0.01) -@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000 - }, 67.5), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000 - }, 75.1), -]) -def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, - tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float) -> None: +@pytest.mark.parametrize( + ("tunable_values", "expected_score"), + [ + ( + { + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 250000, + }, + 67.5, + ), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000, + }, + 75.1, + ), + ], +) +def test_mock_env_no_noise_assign( + mock_env_no_noise: MockEnv, + tunable_groups: TunableGroups, + tunable_values: dict, + expected_score: float, +) -> None: """ Check the benchmark values of the noiseless mock environment after the assignment. 
""" diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index 878531d799..6d47d1fc61 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -38,25 +38,31 @@ def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None: "ssh_priv_key_path": ssh_test_server.id_rsa_path, } - service = ConfigPersistenceService(config={"config_path": [str(files("mlos_bench.tests.config"))]}) + service = ConfigPersistenceService( + config={"config_path": [str(files("mlos_bench.tests.config"))]} + ) config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc") - env = service.load_environment(config_path, TunableGroups(), global_config=global_config, service=service) + env = service.load_environment( + config_path, TunableGroups(), global_config=global_config, service=service + ) check_env_success( - env, env.tunable_params, + env, + env.tunable_params, expected_results={ "hostname": ssh_test_server.service_name, "username": ssh_test_server.username, "score": 0.9, - "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" + "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" "test_param": "unset", "FOO": "unset", "ssh_username": "unset", }, expected_telemetry=[], ) - assert not os.path.exists(os.path.join(os.getcwd(), "output-downloaded.csv")), \ - "output-downloaded.csv should have been cleaned up by temp_dir context" + assert not os.path.exists( + os.path.join(os.getcwd(), "output-downloaded.csv") + ), "output-downloaded.csv should have been cleaned up by temp_dir context" if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index 377bc940a0..fc00e5cb65 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -40,16 +40,21 @@ def __enter__(self) -> None: self.EVENT_LOOP_CONTEXT.enter() self._in_context = True - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: assert self._in_context self.EVENT_LOOP_CONTEXT.exit() self._in_context = False return False -@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") +@pytest.mark.filterwarnings( + "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" +) def test_event_loop_context() -> None: """Test event loop context background thread setup/cleanup handling.""" # pylint: disable=protected-access,too-many-statements @@ -85,14 +90,20 @@ def test_event_loop_context() -> None: with event_loop_caller_instance_2: assert event_loop_caller_instance_2._in_context assert event_loop_caller_instance_1._in_context - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 + ) # We should only get one thread for all instances. 
- assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread \ - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread \ + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop_thread - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop \ - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop \ + ) + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop + ) assert not event_loop_caller_instance_2._in_context @@ -104,31 +115,43 @@ def test_event_loop_context() -> None: assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == 'foo' + assert future.result(timeout=0.2) == "foo" assert 0.1 <= time.time() - start <= 0.2 # Once we exit the last context, the background thread should be stopped # and unusable for running co-routines. - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 0 - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None + ) assert not EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() # Check that the event loop has no more tasks. - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_ready') + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_ready") # Windows ProactorEventLoopPolicy adds a dummy task. 
- if sys.platform == 'win32' and isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop): + if sys.platform == "win32" and isinstance( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop + ): assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 1 else: assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 0 - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_scheduled') + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_scheduled") assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) == 0 - with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) - raise ValueError(f"Future should not have been available to wait on {future.result()}") + with pytest.raises( + AssertionError + ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) + raise ValueError( + f"Future should not have been available to wait on {future.result()}" + ) # Test that when re-entering the context we have the same event loop. with event_loop_caller_instance_1: @@ -138,12 +161,14 @@ def test_event_loop_context() -> None: # Test running again. start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == 'foo' + assert future.result(timeout=0.2) == "foo" assert 0.1 <= time.time() - start <= 0.2 -if __name__ == '__main__': +if __name__ == "__main__": # For debugging in Windows which has issues with pytest detection in vscode. pytest.main(["-n1", "--dist=no", "-k", "test_event_loop_context"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py index 90aa7e08f7..25abf659ce 100644 --- a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py @@ -14,19 +14,33 @@ @pytest.mark.parametrize( - ("argv", "expected_score"), [ - ([ - "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", - "--trial_config_repeat_count", "5", - "--mock_env_seed", "-1", # Deterministic Mock Environment. - ], 67.40329), - ([ - "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", - "--trial_config_repeat_count", "3", - "--max_suggestions", "3", - "--mock_env_seed", "42", # Noisy Mock Environment. - ], 64.53897), - ] + ("argv", "expected_score"), + [ + ( + [ + "--config", + "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", + "--trial_config_repeat_count", + "5", + "--mock_env_seed", + "-1", # Deterministic Mock Environment. + ], + 67.40329, + ), + ( + [ + "--config", + "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", + "--trial_config_repeat_count", + "3", + "--max_suggestions", + "3", + "--mock_env_seed", + "42", # Noisy Mock Environment. 
+ ], + 64.53897, + ), + ], ) def test_main_bench(argv: List[str], expected_score: float) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index 634050d099..39a9ae1a9b 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -48,8 +48,8 @@ def config_paths() -> List[str]: """ return [ path_join(os.getcwd(), abs_path=True), - str(files('mlos_bench.config')), - str(files('mlos_bench.tests.config')), + str(files("mlos_bench.config")), + str(files("mlos_bench.tests.config")), ] @@ -64,20 +64,23 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == 'win32': + if sys.platform == "win32": # Some env tweaks for platform compatibility. - environ['USER'] = environ['USERNAME'] + environ["USER"] = environ["USERNAME"] # This is part of the minimal required args by the Launcher. - env_conf_path = 'environments/mock/mock_env.jsonc' - cli_args = '--config-paths ' + ' '.join(config_paths) + \ - ' --service services/remote/mock/mock_auth_service.jsonc' + \ - ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ - ' --scheduler schedulers/sync_scheduler.jsonc' + \ - f' --environment {env_conf_path}' + \ - ' --globals globals/global_test_config.jsonc' + \ - ' --globals globals/global_test_extra_config.jsonc' \ - ' --test_global_value_2 from-args' + env_conf_path = "environments/mock/mock_env.jsonc" + cli_args = ( + "--config-paths " + + " ".join(config_paths) + + " --service services/remote/mock/mock_auth_service.jsonc" + + " --service services/remote/mock/mock_remote_exec_service.jsonc" + + " --scheduler schedulers/sync_scheduler.jsonc" + + f" --environment {env_conf_path}" + + " --globals globals/global_test_config.jsonc" + + " --globals globals/global_test_extra_config.jsonc" + " --test_global_value_2 from-args" + ) launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -85,30 +88,35 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsLocalExec) assert isinstance(launcher.service, SupportsRemoteExec) # Check that the first --globals file is loaded and $var expansion is handled. - assert launcher.global_config['experiment_id'] == 'MockExperiment' - assert launcher.global_config['testVmName'] == 'MockExperiment-vm' + assert launcher.global_config["experiment_id"] == "MockExperiment" + assert launcher.global_config["testVmName"] == "MockExperiment-vm" # Check that secondary expansion also works. - assert launcher.global_config['testVnetName'] == 'MockExperiment-vm-vnet' + assert launcher.global_config["testVnetName"] == "MockExperiment-vm-vnet" # Check that the second --globals file is loaded. - assert launcher.global_config['test_global_value'] == 'from-file' + assert launcher.global_config["test_global_value"] == "from-file" # Check overriding values in a file from the command line. - assert launcher.global_config['test_global_value_2'] == 'from-args' + assert launcher.global_config["test_global_value_2"] == "from-args" # Check that we can expand a $var in a config file that references an environment variable. 
- assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ - == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' + assert path_join( + launcher.global_config["pathVarWithEnvVarRef"], abs_path=True + ) == path_join(os.getcwd(), "foo", abs_path=True) + assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert launcher.teardown # Check that the environment that got loaded looks to be of the right type. - env_config = launcher.config_loader.load_config(env_conf_path, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config['class']) + env_config = launcher.config_loader.load_config( + env_conf_path, ConfigSchema.ENVIRONMENT + ) + assert check_class_name(launcher.environment, env_config["class"]) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, OneShotOptimizer) # Check that the optimizer got initialized with defaults. assert launcher.optimizer.tunable_params.is_defaults() - assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer + assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer # Check that we pick up the right scheduler config: assert isinstance(launcher.scheduler, SyncScheduler) - assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access + assert ( + launcher.scheduler._trial_config_repeat_count == 3 + ) # pylint: disable=protected-access assert launcher.scheduler._max_trials == -1 # pylint: disable=protected-access @@ -122,23 +130,25 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == 'win32': + if sys.platform == "win32": # Some env tweaks for platform compatibility. - environ['USER'] = environ['USERNAME'] - - config_file = 'cli/test-cli-config.jsonc' - globals_file = 'globals/global_test_config.jsonc' - cli_args = ' '.join([f"--config-path {config_path}" for config_path in config_paths]) + \ - f' --config {config_file}' + \ - ' --service services/remote/mock/mock_auth_service.jsonc' + \ - ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ - f' --globals {globals_file}' + \ - ' --experiment_id MockeryExperiment' + \ - ' --no-teardown' + \ - ' --random-init' + \ - ' --random-seed 1234' + \ - ' --trial-config-repeat-count 5' + \ - ' --max_trials 200' + environ["USER"] = environ["USERNAME"] + + config_file = "cli/test-cli-config.jsonc" + globals_file = "globals/global_test_config.jsonc" + cli_args = ( + " ".join([f"--config-path {config_path}" for config_path in config_paths]) + + f" --config {config_file}" + + " --service services/remote/mock/mock_auth_service.jsonc" + + " --service services/remote/mock/mock_remote_exec_service.jsonc" + + f" --globals {globals_file}" + + " --experiment_id MockeryExperiment" + + " --no-teardown" + + " --random-init" + + " --random-seed 1234" + + " --trial-config-repeat-count 5" + + " --max_trials 200" + ) launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -148,35 +158,48 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsRemoteExec) # Check that the --globals file is loaded and $var expansion is handled # using the value provided on the CLI. 
- assert launcher.global_config['experiment_id'] == 'MockeryExperiment' - assert launcher.global_config['testVmName'] == 'MockeryExperiment-vm' + assert launcher.global_config["experiment_id"] == "MockeryExperiment" + assert launcher.global_config["testVmName"] == "MockeryExperiment-vm" # Check that secondary expansion also works. - assert launcher.global_config['testVnetName'] == 'MockeryExperiment-vm-vnet' + assert launcher.global_config["testVnetName"] == "MockeryExperiment-vm-vnet" # Check that we can expand a $var in a config file that references an environment variable. - assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ - == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' + assert path_join( + launcher.global_config["pathVarWithEnvVarRef"], abs_path=True + ) == path_join(os.getcwd(), "foo", abs_path=True) + assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert not launcher.teardown config = launcher.config_loader.load_config(config_file, ConfigSchema.CLI) - assert launcher.config_loader.config_paths == [path_join(path, abs_path=True) for path in config_paths + config['config_path']] + assert launcher.config_loader.config_paths == [ + path_join(path, abs_path=True) for path in config_paths + config["config_path"] + ] # Check that the environment that got loaded looks to be of the right type. - env_config_file = config['environment'] - env_config = launcher.config_loader.load_config(env_config_file, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config['class']) + env_config_file = config["environment"] + env_config = launcher.config_loader.load_config( + env_config_file, ConfigSchema.ENVIRONMENT + ) + assert check_class_name(launcher.environment, env_config["class"]) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, MlosCoreOptimizer) - opt_config_file = config['optimizer'] - opt_config = launcher.config_loader.load_config(opt_config_file, ConfigSchema.OPTIMIZER) - globals_file_config = launcher.config_loader.load_config(globals_file, ConfigSchema.GLOBALS) + opt_config_file = config["optimizer"] + opt_config = launcher.config_loader.load_config( + opt_config_file, ConfigSchema.OPTIMIZER + ) + globals_file_config = launcher.config_loader.load_config( + globals_file, ConfigSchema.GLOBALS + ) # The actual global_config gets overwritten as a part of processing, so to test # this we read the original value out of the source files. - orig_max_iters = globals_file_config.get('max_suggestions', opt_config.get('config', {}).get('max_suggestions', 100)) - assert launcher.optimizer.max_iterations \ - == orig_max_iters \ - == launcher.global_config['max_suggestions'] + orig_max_iters = globals_file_config.get( + "max_suggestions", opt_config.get("config", {}).get("max_suggestions", 100) + ) + assert ( + launcher.optimizer.max_iterations + == orig_max_iters + == launcher.global_config["max_suggestions"] + ) # Check that the optimizer got initialized with random values instead of the defaults. 
# Note: the environment doesn't get updated until suggest() is called to @@ -189,16 +212,18 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # Check that CLI parameter overrides JSON config: assert isinstance(launcher.scheduler, SyncScheduler) - assert launcher.scheduler._trial_config_repeat_count == 5 # pylint: disable=protected-access + assert ( + launcher.scheduler._trial_config_repeat_count == 5 + ) # pylint: disable=protected-access assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access # Check that the value from the file is overridden by the CLI arg. - assert config['random_seed'] == 42 + assert config["random_seed"] == 42 # TODO: This isn't actually respected yet because the `--random-init` only # applies to a temporary Optimizer used to populate the initial values via # random sampling. # assert launcher.optimizer.seed == 1234 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__, "-n1"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index 591501d275..508923f37d 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -31,16 +31,24 @@ def local_exec_service() -> LocalExecService: """ Test fixture for LocalExecService. """ - return LocalExecService(parent=ConfigPersistenceService({ - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - })) + return LocalExecService( + parent=ConfigPersistenceService( + { + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + } + ) + ) -def _launch_main_app(root_path: str, local_exec_service: LocalExecService, - cli_config: str, re_expected: List[str]) -> None: +def _launch_main_app( + root_path: str, + local_exec_service: LocalExecService, + cli_config: str, + re_expected: List[str], +) -> None: """ Run mlos_bench command-line application with given config and check the results in the log. @@ -52,10 +60,13 @@ def _launch_main_app(root_path: str, local_exec_service: LocalExecService, # temp_dir = '/tmp' log_path = path_join(temp_dir, "mock-test.log") (return_code, _stdout, _stderr) = local_exec_service.local_exec( - ["./mlos_bench/mlos_bench/run.py" + - " --config_path ./mlos_bench/mlos_bench/tests/config/" + - f" {cli_config} --log_file '{log_path}'"], - cwd=root_path) + [ + "./mlos_bench/mlos_bench/run.py" + + " --config_path ./mlos_bench/mlos_bench/tests/config/" + + f" {cli_config} --log_file '{log_path}'" + ], + cwd=root_path, + ) assert return_code == 0 try: @@ -73,65 +84,73 @@ def _launch_main_app(root_path: str, local_exec_service: LocalExecService, _RE_DATE = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}" -def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecService) -> None: +def test_launch_main_app_bench( + root_path: str, local_exec_service: LocalExecService +) -> None: """ Run mlos_bench command-line application with mock benchmark config and default tunable values and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - " --config cli/mock-bench.jsonc" + - " --trial_config_repeat_count 5" + - " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, + local_exec_service, + " --config cli/mock-bench.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. 
[ - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", + ], ) def test_launch_main_app_bench_values( - root_path: str, local_exec_service: LocalExecService) -> None: + root_path: str, local_exec_service: LocalExecService +) -> None: """ Run mlos_bench command-line application with mock benchmark config and user-specified tunable values and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - " --config cli/mock-bench.jsonc" + - " --tunable_values tunable-values/tunable-values-example.jsonc" + - " --trial_config_repeat_count 5" + - " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, + local_exec_service, + " --config cli/mock-bench.jsonc" + + " --tunable_values tunable-values/tunable-values-example.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. [ - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", + ], ) -def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecService) -> None: +def test_launch_main_app_opt( + root_path: str, local_exec_service: LocalExecService +) -> None: """ Run mlos_bench command-line application with mock optimization config and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - "--config cli/mock-opt.jsonc" + - " --trial_config_repeat_count 3" + - " --max_suggestions 3" + - " --mock_env_seed 42", # Noisy Mock Environment. + root_path, + local_exec_service, + "--config cli/mock-opt.jsonc" + + " --trial_config_repeat_count 3" + + " --max_suggestions 3" + + " --mock_env_seed 42", # Noisy Mock Environment. 
[ # Iteration 1: Expect first value to be the baseline - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", # Iteration 2: The result may not always be deterministic - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Iteration 3: non-deterministic (depends on the optimizer) - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Final result: baseline is the optimum for the mock environment - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", + ], ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index 59a0fac13b..6e22350d00 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -23,29 +23,29 @@ def mock_configs() -> List[dict]: """ return [ { - 'vmSize': 'Standard_B4ms', - 'idle': 'halt', - 'kernel_sched_migration_cost_ns': 50000, - 'kernel_sched_latency_ns': 1000000, + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 50000, + "kernel_sched_latency_ns": 1000000, }, { - 'vmSize': 'Standard_B4ms', - 'idle': 'halt', - 'kernel_sched_migration_cost_ns': 40000, - 'kernel_sched_latency_ns': 2000000, + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000, + "kernel_sched_latency_ns": 2000000, }, { - 'vmSize': 'Standard_B4ms', - 'idle': 'mwait', - 'kernel_sched_migration_cost_ns': -1, # Special value - 'kernel_sched_latency_ns': 3000000, + "vmSize": "Standard_B4ms", + "idle": "mwait", + "kernel_sched_migration_cost_ns": -1, # Special value + "kernel_sched_latency_ns": 3000000, }, { - 'vmSize': 'Standard_B2s', - 'idle': 'mwait', - 'kernel_sched_migration_cost_ns': 200000, - 'kernel_sched_latency_ns': 4000000, - } + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 200000, + "kernel_sched_latency_ns": 4000000, + }, ] @@ -61,7 +61,7 @@ def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: "optimization_targets": {"score": "min"}, "max_suggestions": 5, "start_with_defaults": False, - "seed": SEED + "seed": SEED, }, ) @@ -77,7 +77,7 @@ def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: config={ "optimization_targets": {"score": "min"}, "max_suggestions": 5, - "seed": SEED + "seed": SEED, }, ) @@ -93,7 +93,7 @@ def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: config={ "optimization_targets": {"score": "max"}, "max_suggestions": 10, - "seed": SEED + "seed": SEED, }, ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index 9e9ce25d6f..cceac9099b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -20,6 +20,7 @@ # pylint: 
disable=redefined-outer-name + @pytest.fixture def grid_search_tunables_config() -> dict: """ @@ -51,14 +52,27 @@ def grid_search_tunables_config() -> dict: @pytest.fixture -def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]: +def grid_search_tunables_grid( + grid_search_tunables: TunableGroups, +) -> List[Dict[str, TunableValue]]: """ Test fixture for grid from tunable groups. Used to check that the grids are the same (ignoring order). """ - tunables_params_values = [tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None] - tunable_names = tuple(tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None) - return list(dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values)) + tunables_params_values = [ + tunable.values + for tunable, _group in grid_search_tunables + if tunable.values is not None + ] + tunable_names = tuple( + tunable.name + for tunable, _group in grid_search_tunables + if tunable.values is not None + ) + return list( + dict(zip(tunable_names, combo)) + for combo in itertools.product(*tunables_params_values) + ) @pytest.fixture @@ -70,8 +84,10 @@ def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups: @pytest.fixture -def grid_search_opt(grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer: +def grid_search_opt( + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], +) -> GridSearchOptimizer: """ Test fixture for grid search optimizer. """ @@ -79,20 +95,27 @@ def grid_search_opt(grid_search_tunables: TunableGroups, # Test the convergence logic by controlling the number of iterations to be not a # multiple of the number of elements in the grid. max_iterations = len(grid_search_tunables_grid) * 2 - 3 - return GridSearchOptimizer(tunables=grid_search_tunables, config={ - "max_suggestions": max_iterations, - "optimization_targets": {"score": "max", "other_score": "min"}, - }) + return GridSearchOptimizer( + tunables=grid_search_tunables, + config={ + "max_suggestions": max_iterations, + "optimization_targets": {"score": "max", "other_score": "min"}, + }, + ) -def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: +def test_grid_search_grid( + grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], +) -> None: """ Make sure that grid search optimizer initializes and works correctly. """ # Check the size. - expected_grid_size = math.prod(tunable.cardinality for tunable, _group in grid_search_tunables) + expected_grid_size = math.prod( + tunable.cardinality for tunable, _group in grid_search_tunables + ) assert expected_grid_size > len(grid_search_tunables) assert len(grid_search_tunables_grid) == expected_grid_size # Check for specific example configs inclusion. @@ -108,15 +131,23 @@ def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, # Note: ConfigSpace param name vs TunableGroup parameter name order is not # consistent, so we need to full dict comparison. 
assert len(grid_search_opt_pending_configs) == expected_grid_size - assert all(config in grid_search_tunables_grid for config in grid_search_opt_pending_configs) - assert all(config in grid_search_opt_pending_configs for config in grid_search_tunables_grid) + assert all( + config in grid_search_tunables_grid + for config in grid_search_opt_pending_configs + ) + assert all( + config in grid_search_opt_pending_configs + for config in grid_search_tunables_grid + ) # Order is less relevant to us, so we'll just check that the sets are the same. # assert grid_search_opt.pending_configs == grid_search_tunables_grid -def test_grid_search(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: +def test_grid_search( + grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], +) -> None: """ Make sure that grid search optimizer initializes and works correctly. """ @@ -142,8 +173,14 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, grid_search_tunables_grid.remove(default_config) assert default_config not in grid_search_opt.pending_configs - assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) + assert all( + config in grid_search_tunables_grid + for config in grid_search_opt.pending_configs + ) + assert all( + config in list(grid_search_opt.pending_configs) + for config in grid_search_tunables_grid + ) # The next suggestion should be a different element in the grid search. suggestion = grid_search_opt.suggest() @@ -156,8 +193,14 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, assert suggestion_dict not in grid_search_opt.suggested_configs grid_search_tunables_grid.remove(suggestion.get_param_values()) - assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) + assert all( + config in grid_search_tunables_grid + for config in grid_search_opt.pending_configs + ) + assert all( + config in list(grid_search_opt.pending_configs) + for config in grid_search_tunables_grid + ) # We consider not_converged as either having reached "max_suggestions" or an empty grid? 
@@ -223,7 +266,7 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: assert best_suggestion_dict not in grid_search_opt.suggested_configs best_suggestion_score: Dict[str, TunableValue] = {} - for (opt_target, opt_dir) in grid_search_opt.targets.items(): + for opt_target, opt_dir in grid_search_opt.targets.items(): val = score[opt_target] assert isinstance(val, (int, float)) best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1 @@ -237,36 +280,57 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: # Check bulk register suggested = [grid_search_opt.suggest() for _ in range(suggest_count)] - assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) - assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested) + assert all( + suggestion.get_param_values() not in grid_search_opt.pending_configs + for suggestion in suggested + ) + assert all( + suggestion.get_param_values() in grid_search_opt.suggested_configs + for suggestion in suggested + ) # Those new suggestions also shouldn't be in the set of previously suggested configs. - assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested) - - grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested], - [score] * len(suggested), - [status] * len(suggested)) - - assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) - assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested) + assert all( + suggestion.get_param_values() not in suggested_shuffled + for suggestion in suggested + ) + + grid_search_opt.bulk_register( + [suggestion.get_param_values() for suggestion in suggested], + [score] * len(suggested), + [status] * len(suggested), + ) + + assert all( + suggestion.get_param_values() not in grid_search_opt.pending_configs + for suggestion in suggested + ) + assert all( + suggestion.get_param_values() not in grid_search_opt.suggested_configs + for suggestion in suggested + ) best_score, best_config = grid_search_opt.get_best_observation() assert best_score == best_suggestion_score assert best_config == best_suggestion -def test_grid_search_register(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups) -> None: +def test_grid_search_register( + grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups +) -> None: """ Make sure that the `.register()` method adjusts the score signs correctly. 
""" assert grid_search_opt.register( - grid_search_tunables, Status.SUCCEEDED, { + grid_search_tunables, + Status.SUCCEEDED, + { "score": 1.0, "other_score": 2.0, - }) == { - "score": -1.0, # max - "other_score": 2.0, # min + }, + ) == { + "score": -1.0, # max + "other_score": 2.0, # min } assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == { diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py index 6549a8795c..07eec4655f 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py @@ -34,7 +34,8 @@ def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: "optimizer_type": "SMAC", "seed": SEED, # "start_with_defaults": False, - }) + }, + ) @pytest.fixture @@ -45,7 +46,9 @@ def mock_scores() -> list: return [88.88, 66.66, 99.99] -def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None: +def test_llamatune_optimizer( + llamatune_opt: MlosCoreOptimizer, mock_scores: list +) -> None: """ Make sure that llamatune+smac optimizer initializes and works correctly. """ @@ -61,6 +64,6 @@ def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list assert best_score["score"] == pytest.approx(66.66, 0.01) -if __name__ == '__main__': +if __name__ == "__main__": # For attaching debugger debugging: pytest.main(["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__]) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py index 7ebba0e664..c824d9774f 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py @@ -24,9 +24,9 @@ def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer: An instance of a mlos_core optimizer (FLAML-based). 
""" test_opt_config = { - 'optimizer_type': 'FLAML', - 'max_suggestions': 10, - 'seed': SEED, + "optimizer_type": "FLAML", + "max_suggestions": 10, + "seed": SEED, } return MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -39,44 +39,44 @@ def test_df(mlos_core_optimizer: MlosCoreOptimizer, mock_configs: List[dict]) -> assert isinstance(df_config, pandas.DataFrame) assert df_config.shape == (4, 6) assert set(df_config.columns) == { - 'kernel_sched_latency_ns', - 'kernel_sched_migration_cost_ns', - 'kernel_sched_migration_cost_ns!type', - 'kernel_sched_migration_cost_ns!special', - 'idle', - 'vmSize', + "kernel_sched_latency_ns", + "kernel_sched_migration_cost_ns", + "kernel_sched_migration_cost_ns!type", + "kernel_sched_migration_cost_ns!special", + "idle", + "vmSize", } - assert df_config.to_dict(orient='records') == [ + assert df_config.to_dict(orient="records") == [ { - 'idle': 'halt', - 'kernel_sched_latency_ns': 1000000, - 'kernel_sched_migration_cost_ns': 50000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B4ms', + "idle": "halt", + "kernel_sched_latency_ns": 1000000, + "kernel_sched_migration_cost_ns": 50000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B4ms", }, { - 'idle': 'halt', - 'kernel_sched_latency_ns': 2000000, - 'kernel_sched_migration_cost_ns': 40000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B4ms', + "idle": "halt", + "kernel_sched_latency_ns": 2000000, + "kernel_sched_migration_cost_ns": 40000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B4ms", }, { - 'idle': 'mwait', - 'kernel_sched_latency_ns': 3000000, - 'kernel_sched_migration_cost_ns': None, # The value is special! - 'kernel_sched_migration_cost_ns!special': -1, - 'kernel_sched_migration_cost_ns!type': 'special', - 'vmSize': 'Standard_B4ms', + "idle": "mwait", + "kernel_sched_latency_ns": 3000000, + "kernel_sched_migration_cost_ns": None, # The value is special! + "kernel_sched_migration_cost_ns!special": -1, + "kernel_sched_migration_cost_ns!type": "special", + "vmSize": "Standard_B4ms", }, { - 'idle': 'mwait', - 'kernel_sched_latency_ns': 4000000, - 'kernel_sched_migration_cost_ns': 200000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B2s', + "idle": "mwait", + "kernel_sched_latency_ns": 4000000, + "kernel_sched_migration_cost_ns": 200000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B2s", }, ] diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index fc62b4ff1b..95d51cbe22 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -17,8 +17,8 @@ from mlos_bench.util import path_join from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer -_OUTPUT_DIR_PATH_BASE = r'c:/temp' if sys.platform == 'win32' else '/tmp/' -_OUTPUT_DIR = '_test_output_dir' # Will be deleted after the test. +_OUTPUT_DIR_PATH_BASE = r"c:/temp" if sys.platform == "win32" else "/tmp/" +_OUTPUT_DIR = "_test_output_dir" # Will be deleted after the test. 
def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) -> None: @@ -26,10 +26,10 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) Test invalid max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'max_trials': 10, - 'max_suggestions': 11, - 'seed': SEED, + "optimizer_type": "SMAC", + "max_trials": 10, + "max_suggestions": 11, + "seed": SEED, } with pytest.raises(AssertionError): opt = MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -41,25 +41,29 @@ def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> No Test max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'max_suggestions': 123, - 'seed': SEED, + "optimizer_type": "SMAC", + "max_suggestions": 123, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) - assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config['max_suggestions'] + assert ( + opt._opt.base_optimizer.scenario.n_trials == test_opt_config["max_suggestions"] + ) -def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_absolute_output_directory( + tunable_groups: TunableGroups, +) -> None: """ Test absolute path output directory initialization of mlos_core SMAC optimizer. """ output_dir = path_join(_OUTPUT_DIR_PATH_BASE, _OUTPUT_DIR) test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': output_dir, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": output_dir, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) @@ -67,76 +71,96 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr assert isinstance(opt._opt, SmacOptimizer) # Final portions of the path are generated by SMAC when run_name is not specified. assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - str(test_opt_config['output_directory'])) + str(test_opt_config["output_directory"]) + ) shutil.rmtree(output_dir) -def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_relative_output_directory( + tunable_groups: TunableGroups, +) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']))) + path_join(os.getcwd(), str(test_opt_config["output_directory"])) + ) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_run_name( + tunable_groups: TunableGroups, +) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. 
""" test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'run_name': 'test_run', - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "run_name": "test_run", + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']), str(test_opt_config['run_name']))) + path_join( + os.getcwd(), + str(test_opt_config["output_directory"]), + str(test_opt_config["run_name"]), + ) + ) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_experiment_id( + tunable_groups: TunableGroups, +) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "seed": SEED, } global_config = { - 'experiment_id': 'experiment_id', + "experiment_id": "experiment_id", } opt = MlosCoreOptimizer(tunable_groups, test_opt_config, global_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']), global_config['experiment_id'])) + path_join( + os.getcwd(), + str(test_opt_config["output_directory"]), + global_config["experiment_id"], + ) + ) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_temp_output_directory(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_temp_output_directory( + tunable_groups: TunableGroups, +) -> None: """ Test random output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': None, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": None, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py index a94a315939..739e27114b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py @@ -20,24 +20,33 @@ def mock_configurations_no_defaults() -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. 
""" return [ - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 13112, - "kernel_sched_latency_ns": 796233790, - }, 88.88), - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 117026, - "kernel_sched_latency_ns": 149827706, - }, 66.66), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 354785, - "kernel_sched_latency_ns": 795285932, - }, 99.99), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 13112, + "kernel_sched_latency_ns": 796233790, + }, + 88.88, + ), + ( + { + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 117026, + "kernel_sched_latency_ns": 149827706, + }, + 66.66, + ), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 354785, + "kernel_sched_latency_ns": 795285932, + }, + 99.99, + ), ] @@ -47,12 +56,15 @@ def mock_configurations(mock_configurations_no_defaults: list) -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. """ return [ - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": -1, - "kernel_sched_latency_ns": 2000000, - }, 88.88), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": -1, + "kernel_sched_latency_ns": 2000000, + }, + 88.88, + ), ] + mock_configurations_no_defaults @@ -60,7 +72,7 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: """ Run several iterations of the optimizer and return the best score. """ - for (tunable_values, score) in mock_configurations: + for tunable_values, score in mock_configurations: assert mock_opt.not_converged() tunables = mock_opt.suggest() assert tunables.get_param_values() == tunable_values @@ -80,8 +92,9 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> N assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer, - mock_configurations_no_defaults: list) -> None: +def test_mock_optimizer_no_defaults( + mock_opt_no_defaults: MockOptimizer, mock_configurations_no_defaults: list +) -> None: """ Make sure that mock optimizer produces consistent suggestions. """ @@ -89,7 +102,9 @@ def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer, assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list) -> None: +def test_mock_optimizer_max( + mock_opt_max: MockOptimizer, mock_configurations: list +) -> None: """ Check the maximization mode of the mock optimizer. """ diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py index bf37040f13..ccc0ba8137 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py @@ -25,10 +25,7 @@ def mock_configs_str(mock_configs: List[dict]) -> List[dict]: Same as `mock_config` above, but with all values converted to strings. (This can happen when we retrieve the data from storage). 
""" - return [ - {key: str(val) for (key, val) in config.items()} - for config in mock_configs - ] + return [{key: str(val) for (key, val) in config.items()} for config in mock_configs] @pytest.fixture @@ -52,10 +49,12 @@ def mock_status() -> List[Status]: return [Status.FAILED, Status.SUCCEEDED, Status.SUCCEEDED, Status.SUCCEEDED] -def _test_opt_update_min(opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None) -> None: +def _test_opt_update_min( + opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None, +) -> None: """ Test the bulk update of the optimizer on the minimization problem. """ @@ -68,14 +67,16 @@ def _test_opt_update_min(opt: Optimizer, "vmSize": "Standard_B4ms", "idle": "mwait", "kernel_sched_migration_cost_ns": -1, - 'kernel_sched_latency_ns': 3000000, + "kernel_sched_latency_ns": 3000000, } -def _test_opt_update_max(opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None) -> None: +def _test_opt_update_max( + opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None, +) -> None: """ Test the bulk update of the optimizer on the maximization problem. """ @@ -88,14 +89,16 @@ def _test_opt_update_max(opt: Optimizer, "vmSize": "Standard_B2s", "idle": "mwait", "kernel_sched_migration_cost_ns": 200000, - 'kernel_sched_latency_ns': 4000000, + "kernel_sched_latency_ns": 4000000, } -def test_update_mock_min(mock_opt: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_min( + mock_opt: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the mock optimizer on the minimization problem. """ @@ -105,64 +108,76 @@ def test_update_mock_min(mock_opt: MockOptimizer, "vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 13112, - 'kernel_sched_latency_ns': 796233790, + "kernel_sched_latency_ns": 796233790, } -def test_update_mock_min_str(mock_opt: MockOptimizer, - mock_configs_str: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_min_str( + mock_opt: MockOptimizer, + mock_configs_str: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the mock optimizer with all-strings data. """ _test_opt_update_min(mock_opt, mock_configs_str, mock_scores, mock_status) -def test_update_mock_max(mock_opt_max: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_max( + mock_opt_max: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the mock optimizer on the maximization problem. 
""" _test_opt_update_max(mock_opt_max, mock_configs, mock_scores, mock_status) -def test_update_flaml(flaml_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_flaml( + flaml_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_min(flaml_opt, mock_configs, mock_scores, mock_status) -def test_update_flaml_max(flaml_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_flaml_max( + flaml_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_max(flaml_opt_max, mock_configs, mock_scores, mock_status) -def test_update_smac(smac_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_smac( + smac_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the SMAC optimizer. """ _test_opt_update_min(smac_opt, mock_configs, mock_scores, mock_status) -def test_update_smac_max(smac_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_smac_max( + smac_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the SMAC optimizer. 
""" diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index 2a50f95e8c..d5068e0656 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -42,12 +42,16 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: logger("tunables: %s", str(tunables)) # pylint: disable=protected-access - if isinstance(opt, MlosCoreOptimizer) and isinstance(opt._opt, SmacOptimizer): + if isinstance(opt, MlosCoreOptimizer) and isinstance( + opt._opt, SmacOptimizer + ): config = tunable_values_to_configuration(tunables) config_df = config_to_dataframe(config) logger("config: %s", str(config)) try: - logger("prediction: %s", opt._opt.surrogate_predict(configs=config_df)) + logger( + "prediction: %s", opt._opt.surrogate_predict(configs=config_df) + ) except RuntimeError: pass @@ -56,7 +60,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: (status, _ts, output) = env_context.run() assert status.is_succeeded() assert output is not None - score = output['score'] + score = output["score"] assert isinstance(score, float) assert 60 <= score <= 120 logger("score: %s", str(score)) @@ -69,8 +73,9 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: return (best_score["score"], best_tunables) -def test_mock_optimization_loop(mock_env_no_noise: MockEnv, - mock_opt: MockOptimizer) -> None: +def test_mock_optimization_loop( + mock_env_no_noise: MockEnv, mock_opt: MockOptimizer +) -> None: """ Toy optimization loop with mock environment and optimizer. """ @@ -84,8 +89,9 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv, } -def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, - mock_opt_no_defaults: MockOptimizer) -> None: +def test_mock_optimization_loop_no_defaults( + mock_env_no_noise: MockEnv, mock_opt_no_defaults: MockOptimizer +) -> None: """ Toy optimization loop with mock environment and optimizer. """ @@ -99,8 +105,9 @@ def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, } -def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, - flaml_opt: MlosCoreOptimizer) -> None: +def test_flaml_optimization_loop( + mock_env_no_noise: MockEnv, flaml_opt: MlosCoreOptimizer +) -> None: """ Toy optimization loop with mock environment and FLAML optimizer. """ @@ -115,8 +122,9 @@ def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, # @pytest.mark.skip(reason="SMAC is not deterministic") -def test_smac_optimization_loop(mock_env_no_noise: MockEnv, - smac_opt: MlosCoreOptimizer) -> None: +def test_smac_optimization_loop( + mock_env_no_noise: MockEnv, smac_opt: MlosCoreOptimizer +) -> None: """ Toy optimization loop with mock environment and SMAC optimizer. 
""" diff --git a/mlos_bench/mlos_bench/tests/services/__init__.py b/mlos_bench/mlos_bench/tests/services/__init__.py index 1971c01799..bf4df0e6c2 100644 --- a/mlos_bench/mlos_bench/tests/services/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/__init__.py @@ -11,8 +11,8 @@ from .remote import MockFileShareService, MockRemoteExecService, MockVMService __all__ = [ - 'MockLocalExecService', - 'MockFileShareService', - 'MockRemoteExecService', - 'MockVMService', + "MockLocalExecService", + "MockFileShareService", + "MockRemoteExecService", + "MockVMService", ] diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index d6cb869f09..8f51dd9f85 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -29,18 +29,24 @@ def config_persistence_service() -> ConfigPersistenceService: """ Test fixture for ConfigPersistenceService. """ - return ConfigPersistenceService({ - "config_path": [ - "./non-existent-dir/test/foo/bar", # Non-existent config path - ".", # cwd - str(files("mlos_bench.tests.config").joinpath("")), # Test configs (relative to mlos_bench/tests) - # Shouldn't be necessary since we automatically add this. - # str(files("mlos_bench.config").joinpath("")), # Stock configs - ] - }) - - -def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None: + return ConfigPersistenceService( + { + "config_path": [ + "./non-existent-dir/test/foo/bar", # Non-existent config path + ".", # cwd + str( + files("mlos_bench.tests.config").joinpath("") + ), # Test configs (relative to mlos_bench/tests) + # Shouldn't be necessary since we automatically add this. + # str(files("mlos_bench.config").joinpath("")), # Stock configs + ] + } + ) + + +def test_cwd_in_explicit_search_path( + config_persistence_service: ConfigPersistenceService, +) -> None: """ Check that CWD is in the search path in the correct place. """ @@ -65,20 +71,25 @@ def test_cwd_in_default_search_path() -> None: config_persistence_service._config_path.index(cwd, 1) -def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService) -> None: +def test_resolve_stock_path( + config_persistence_service: ConfigPersistenceService, +) -> None: """ Check if we can actually find a file somewhere in `config_path`. """ # pylint: disable=protected-access assert config_persistence_service._config_path is not None - assert ConfigPersistenceService.BUILTIN_CONFIG_PATH in config_persistence_service._config_path + assert ( + ConfigPersistenceService.BUILTIN_CONFIG_PATH + in config_persistence_service._config_path + ) file_path = "storage/in-memory.jsonc" path = config_persistence_service.resolve_path(file_path) assert path.endswith(file_path) assert os.path.exists(path) assert os.path.samefile( ConfigPersistenceService.BUILTIN_CONFIG_PATH, - os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]) + os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]), ) @@ -92,7 +103,9 @@ def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> N assert os.path.exists(path) -def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None: +def test_resolve_path_fail( + config_persistence_service: ConfigPersistenceService, +) -> None: """ Check if non-existent file resolves without using `config_path`. 
""" @@ -106,8 +119,9 @@ def test_load_config(config_persistence_service: ConfigPersistenceService) -> No """ Check if we can successfully load a config file located relative to `config_path`. """ - tunables_data = config_persistence_service.load_config("tunable-values/tunable-values-example.jsonc", - ConfigSchema.TUNABLE_VALUES) + tunables_data = config_persistence_service.load_config( + "tunable-values/tunable-values-example.jsonc", ConfigSchema.TUNABLE_VALUES + ) assert tunables_data is not None assert isinstance(tunables_data, dict) assert len(tunables_data) >= 1 diff --git a/mlos_bench/mlos_bench/tests/services/local/__init__.py b/mlos_bench/mlos_bench/tests/services/local/__init__.py index c6dbf7c021..a09fd442fb 100644 --- a/mlos_bench/mlos_bench/tests/services/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/__init__.py @@ -10,5 +10,5 @@ from .mock import MockLocalExecService __all__ = [ - 'MockLocalExecService', + "MockLocalExecService", ] diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index 572195dcc5..dafd8ed2fe 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -56,17 +56,22 @@ def test_run_python_script(local_exec_service: LocalExecService) -> None: json.dump(params_meta, fh_meta) script_path = local_exec_service.config_loader_service.resolve_path( - "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py") + "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py" + ) - (return_code, _stdout, stderr) = local_exec_service.local_exec([ - f"{script_path} {input_file} {meta_file} {output_file}" - ], cwd=temp_dir, env=params) + (return_code, _stdout, stderr) = local_exec_service.local_exec( + [f"{script_path} {input_file} {meta_file} {output_file}"], + cwd=temp_dir, + env=params, + ) assert stderr.strip() == "" assert return_code == 0 # assert stdout.strip() == "" - with open(path_join(temp_dir, output_file), "rt", encoding="utf-8") as fh_output: + with open( + path_join(temp_dir, output_file), "rt", encoding="utf-8" + ) as fh_output: assert [ln.strip() for ln in fh_output.readlines()] == [ 'echo "40000" > /proc/sys/kernel/sched_migration_cost_ns', 'echo "800000" > /proc/sys/kernel/sched_granularity_ns', diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index bd5b3b7d7f..04f1f600f3 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -26,23 +26,23 @@ def test_split_cmdline() -> None: """ cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" assert list(split_cmdline(cmdline)) == [ - ['.', 'env.sh'], - ['&&'], - ['('], - ['echo', 'hello'], - ['&&'], - ['echo', 'world'], - ['|'], - ['tee'], - ['>'], - ['/tmp/test'], - ['||'], - ['echo', 'foo'], - ['&&'], - ['echo', '$var'], - [';'], - ['true'], - [')'], + [".", "env.sh"], + ["&&"], + ["("], + ["echo", "hello"], + ["&&"], + ["echo", "world"], + ["|"], + ["tee"], + [">"], + ["/tmp/test"], + ["||"], + ["echo", "foo"], + ["&&"], + ["echo", "$var"], + [";"], + ["true"], + [")"], ] @@ -67,8 +67,13 @@ def test_resolve_script(local_exec_service: LocalExecService) -> None: expected_cmdline = f". 
env.sh && {script_abspath} --input foo" subcmds_tokens = split_cmdline(orig_cmdline) # pylint: disable=protected-access - subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens] - cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens] + subcmds_tokens = [ + local_exec_service._resolve_cmdline_script_path(subcmd_tokens) + for subcmd_tokens in subcmds_tokens + ] + cmdline_tokens = [ + token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens + ] expanded_cmdline = " ".join(cmdline_tokens) assert expanded_cmdline == expected_cmdline @@ -89,10 +94,9 @@ def test_run_script_multiline(local_exec_service: LocalExecService) -> None: Run a multiline script locally and check the results. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec([ - "echo hello", - "echo world" - ]) + (return_code, stdout, stderr) = local_exec_service.local_exec( + ["echo hello", "echo world"] + ) assert return_code == 0 assert stdout.strip().split() == ["hello", "world"] assert stderr.strip() == "" @@ -103,12 +107,12 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: Run a multiline script locally and pass the environment variables to it. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec([ - r"echo $var", # Unix shell - r"echo %var%" # Windows cmd - ], env={"var": "VALUE", "int_var": 10}) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [r"echo $var", r"echo %var%"], # Unix shell # Windows cmd + env={"var": "VALUE", "int_var": 10}, + ) assert return_code == 0 - if sys.platform == 'win32': + if sys.platform == "win32": assert stdout.strip().split() == ["$var", "VALUE"] else: assert stdout.strip().split() == ["VALUE", "%var%"] @@ -121,23 +125,26 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: """ with local_exec_service.temp_dir_context() as temp_dir: - (return_code, stdout, stderr) = local_exec_service.local_exec([ - "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows - "echo '111,222' >> output.csv", - "echo '333,444' >> output.csv", - ], cwd=temp_dir) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [ + "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows + "echo '111,222' >> output.csv", + "echo '333,444' >> output.csv", + ], + cwd=temp_dir, + ) assert return_code == 0 assert stdout.strip() == "" assert stderr.strip() == "" data = pandas.read_csv(path_join(temp_dir, "output.csv")) - if sys.platform == 'win32': + if sys.platform == "win32": # Workaround for Python's subprocess module on Windows adding a # space inbetween the col1,col2 arg and the redirect symbol which # cmd poorly interprets as being part of the original string arg. # Without this, we get "col2 " as the second column name. 
- data.rename(str.rstrip, axis='columns', inplace=True) + data.rename(str.rstrip, axis="columns", inplace=True) assert all(data.col1 == [111, 333]) assert all(data.col2 == [222, 444]) @@ -152,10 +159,13 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None with open(path_join(temp_dir, input_file), "wt", encoding="utf-8") as fh_input: fh_input.write("hello\n") - (return_code, stdout, stderr) = local_exec_service.local_exec([ - f"echo 'world' >> {input_file}", - f"echo 'test' >> {input_file}", - ], cwd=temp_dir) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [ + f"echo 'world' >> {input_file}", + f"echo 'test' >> {input_file}", + ], + cwd=temp_dir, + ) assert return_code == 0 assert stdout.strip() == "" @@ -169,7 +179,9 @@ def test_run_script_fail(local_exec_service: LocalExecService) -> None: """ Try to run a non-existent command. """ - (return_code, stdout, _stderr) = local_exec_service.local_exec(["foo_bar_baz hello"]) + (return_code, stdout, _stderr) = local_exec_service.local_exec( + ["foo_bar_baz hello"] + ) assert return_code != 0 assert stdout.strip() == "" @@ -178,11 +190,13 @@ def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> N """ Try to run a series of commands, one of which fails, and abort early. """ - (return_code, stdout, _stderr) = local_exec_service.local_exec([ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", - "echo world", - ]) + (return_code, stdout, _stderr) = local_exec_service.local_exec( + [ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == "win32" else "false", + "echo world", + ] + ) assert return_code != 0 assert stdout.strip() == "hello" @@ -192,11 +206,13 @@ def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> No Try to run a series of commands, one of which fails, but let it pass. """ local_exec_service.abort_on_error = False - (return_code, stdout, _stderr) = local_exec_service.local_exec([ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", - "echo world", - ]) + (return_code, stdout, _stderr) = local_exec_service.local_exec( + [ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == "win32" else "false", + "echo world", + ] + ) assert return_code == 0 assert stdout.splitlines() == [ "hello", @@ -214,13 +230,17 @@ def test_temp_dir_path_expansion() -> None: # the fact. with tempfile.TemporaryDirectory() as temp_dir: global_config = { - "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" + "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" } config = { # The temp_dir for the LocalExecService should get expanded via workdir global config. 
"temp_dir": "$workdir/temp", } - local_exec_service = LocalExecService(config, global_config, parent=ConfigPersistenceService()) + local_exec_service = LocalExecService( + config, global_config, parent=ConfigPersistenceService() + ) # pylint: disable=protected-access assert isinstance(local_exec_service._temp_dir, str) - assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join(temp_dir, "temp", abs_path=True) + assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join( + temp_dir, "temp", abs_path=True + ) diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py index eede9383bc..9164da60df 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py @@ -9,5 +9,5 @@ from .mock_local_exec_service import MockLocalExecService __all__ = [ - 'MockLocalExecService', + "MockLocalExecService", ] diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py index db8f0134c4..ad47160753 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py @@ -35,16 +35,24 @@ class MockLocalExecService(TempDirContextService, SupportsLocalExec): Mock methods for LocalExecService testing. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.local_exec]) + config, + global_config, + parent, + self.merge_methods(methods, [self.local_exec]), ) - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: return (0, "", "") diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py index 835738015b..4ef38ab440 100644 --- a/mlos_bench/mlos_bench/tests/services/mock_service.py +++ b/mlos_bench/mlos_bench/tests/services/mock_service.py @@ -28,19 +28,24 @@ class MockServiceBase(Service, SupportsSomeMethod): """A base service class for testing.""" def __init__( - self, - config: Optional[dict] = None, - global_config: Optional[dict] = None, - parent: Optional[Service] = None, - methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None) -> None: + self, + config: Optional[dict] = None, + global_config: Optional[dict] = None, + parent: Optional[Service] = None, + methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None, + ) -> None: super().__init__( config, global_config, parent, - self.merge_methods(methods, [ - self.some_method, - self.some_other_method, - ])) + self.merge_methods( + methods, + [ + self.some_method, + self.some_other_method, + ], + ), + ) def some_method(self) -> str: """some_method""" diff 
--git a/mlos_bench/mlos_bench/tests/services/remote/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/__init__.py index e8a87ab684..df3fb69c53 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/__init__.py @@ -12,7 +12,7 @@ from .mock.mock_vm_service import MockVMService __all__ = [ - 'MockFileShareService', - 'MockRemoteExecService', - 'MockVMService', + "MockFileShareService", + "MockRemoteExecService", + "MockVMService", ] diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index 949b712c79..2858b2388c 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -18,16 +18,25 @@ @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_download_file( + mock_makedirs: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = ( + azure_fileshare._share_client + ) # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \ - patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client: + with patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client, patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client: mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False)) azure_fileshare.download(config, remote_path, local_path) @@ -47,38 +56,45 @@ def make_dir_client_returns(remote_folder: str) -> dict: return { remote_folder: Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock(return_value=[ - {"name": "a_folder", "is_directory": True}, - {"name": "a_file_1.csv", "is_directory": False}, - ]) + list_directories_and_files=Mock( + return_value=[ + {"name": "a_folder", "is_directory": True}, + {"name": "a_file_1.csv", "is_directory": False}, + ] + ), ), f"{remote_folder}/a_folder": Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock(return_value=[ - {"name": "a_file_2.csv", "is_directory": False}, - ]) - ), - f"{remote_folder}/a_file_1.csv": Mock( - exists=Mock(return_value=False) - ), - f"{remote_folder}/a_folder/a_file_2.csv": Mock( - exists=Mock(return_value=False) + list_directories_and_files=Mock( + return_value=[ + {"name": "a_file_2.csv", "is_directory": False}, + ] + ), ), + f"{remote_folder}/a_file_1.csv": Mock(exists=Mock(return_value=False)), + f"{remote_folder}/a_folder/a_file_2.csv": Mock(exists=Mock(return_value=False)), } @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_non_recursive(mock_makedirs: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_download_folder_non_recursive( + mock_makedirs: MagicMock, + 
mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = ( + azure_fileshare._share_client + ) # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ - patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + with patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] @@ -87,47 +103,69 @@ def test_download_folder_non_recursive(mock_makedirs: MagicMock, mock_get_file_client.assert_called_with( f"{remote_folder}/a_file_1.csv", ) - mock_get_directory_client.assert_has_calls([ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - ], any_order=True) + mock_get_directory_client.assert_has_calls( + [ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + ], + any_order=True, + ) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_download_folder_recursive( + mock_makedirs: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = ( + azure_fileshare._share_client + ) # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ - patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + with patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] azure_fileshare.download(config, remote_folder, local_folder, recursive=True) - mock_get_file_client.assert_has_calls([ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) - mock_get_directory_client.assert_has_calls([ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) + mock_get_file_client.assert_has_calls( + [ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) + mock_get_directory_client.assert_has_calls( + [ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") -def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_upload_file( + 
mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService +) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = ( + azure_fileshare._share_client + ) # pylint: disable=protected-access mock_isdir.return_value = False config: dict = {} @@ -143,6 +181,7 @@ def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshar class MyDirEntry: # pylint: disable=too-few-public-methods """Dummy class for os.DirEntry""" + def __init__(self, name: str, is_a_dir: bool): self.name = name self.is_a_dir = is_a_dir @@ -176,7 +215,7 @@ def process_paths(input_path: str) -> str: skip_prefix = os.getcwd() # Remove prefix from os.path.abspath if there if input_path == os.path.abspath(input_path): - result = input_path[len(skip_prefix) + 1:] + result = input_path[len(skip_prefix) + 1 :] else: result = input_path # Change file seps to unix-style @@ -186,17 +225,21 @@ def process_paths(input_path: str) -> str: @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_non_recursive(mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_upload_directory_non_recursive( + mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = ( + azure_fileshare._share_client + ) # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: @@ -208,23 +251,30 @@ def test_upload_directory_non_recursive(mock_scandir: MagicMock, @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_recursive(mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_upload_directory_recursive( + mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = ( + azure_fileshare._share_client + ) # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: 
azure_fileshare.upload(config, local_folder, remote_folder, recursive=True) - mock_get_file_client.assert_has_calls([ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) + mock_get_file_client.assert_has_calls( + [ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py index d6d55d3975..7a7a87359a 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py @@ -18,27 +18,41 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_network_deployment_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_network_service: AzureNetworkService) -> None: +def test_wait_network_deployment_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_network_service: AzureNetworkService, +) -> None: """ Test retries of the network deployment operation. """ # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ - make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), + make_httplib_json_response( + 200, {"properties": {"provisioningState": "Running"}} + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + make_httplib_json_response( + 200, {"properties": {"provisioningState": "Running"}} + ), + make_httplib_json_response( + 200, {"properties": {"provisioningState": "Succeeded"}} + ), ] (status, _) = azure_network_service.wait_network_deployment( @@ -49,30 +63,37 @@ def test_wait_network_deployment_retry(mock_getconn: MagicMock, "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True) + is_setup=True, + ) assert status == operation_status @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("deprovision_network", True), - ]) + ], +) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), # These should succeed since we set ignore_errors=True by default (401, Status.SUCCEEDED), (404, Status.SUCCEEDED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_network_operation_status(mock_requests: MagicMock, - 
azure_network_service: AzureNetworkService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status) -> None: +def test_network_operation_status( + mock_requests: MagicMock, + azure_network_service: AzureNetworkService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status, +) -> None: """ Test network operation status. """ @@ -84,27 +105,37 @@ def test_network_operation_status(mock_requests: MagicMock, with pytest.raises(ValueError): # Missing vnetName should raise ValueError (status, _) = operation({}) if accepts_params else operation() - (status, _) = operation({"vnetName": "test-vnet"}) if accepts_params else operation() + (status, _) = ( + operation({"vnetName": "test-vnet"}) if accepts_params else operation() + ) assert status == operation_status @pytest.fixture -def test_azure_network_service_no_deployment_template(azure_auth_service: AzureAuthService) -> None: +def test_azure_network_service_no_deployment_template( + azure_auth_service: AzureAuthService, +) -> None: """ Tests creating a network services without a deployment template (should fail). """ with pytest.raises(ValueError): - _ = AzureNetworkService(config={ - "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", + _ = AzureNetworkService( + config={ + "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", + }, }, - }, parent=azure_auth_service) + parent=azure_auth_service, + ) with pytest.raises(ValueError): - _ = AzureNetworkService(config={ - # "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", + _ = AzureNetworkService( + config={ + # "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", + }, }, - }, parent=azure_auth_service) + parent=azure_auth_service, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 1d84d73cab..0fd94cf821 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -19,27 +19,41 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_host_deployment_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService) -> None: +def test_wait_host_deployment_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService, +) -> None: """ Test retries of the host deployment operation. 
""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ - make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), + make_httplib_json_response( + 200, {"properties": {"provisioningState": "Running"}} + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + make_httplib_json_response( + 200, {"properties": {"provisioningState": "Running"}} + ), + make_httplib_json_response( + 200, {"properties": {"provisioningState": "Succeeded"}} + ), ] (status, _) = azure_vm_service.wait_host_deployment( @@ -50,11 +64,14 @@ def test_wait_host_deployment_retry(mock_getconn: MagicMock, "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True) + is_setup=True, + ) assert status == operation_status -def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None: +def test_azure_vm_service_recursive_template_params( + azure_auth_service: AzureAuthService, +) -> None: """ Test expanding template params recursively. """ @@ -75,8 +92,14 @@ def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAut } azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) assert azure_vm_service.deploy_params["location"] == global_config["location"] - assert azure_vm_service.deploy_params["vmMeta"] == f'{global_config["vmName"]}-{global_config["location"]}' - assert azure_vm_service.deploy_params["vmNsg"] == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' + assert ( + azure_vm_service.deploy_params["vmMeta"] + == f'{global_config["vmName"]}-{global_config["location"]}' + ) + assert ( + azure_vm_service.deploy_params["vmNsg"] + == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' + ) def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None: @@ -98,14 +121,17 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N } with pytest.raises(ValueError): config_with_custom_data = deepcopy(config) - config_with_custom_data['deploymentTemplateParameters']['customData'] = "DUMMY_CUSTOM_DATA" # type: ignore[index] - AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service) + config_with_custom_data["deploymentTemplateParameters"]["customData"] = "DUMMY_CUSTOM_DATA" # type: ignore[index] + AzureVMService( + config_with_custom_data, global_config, parent=azure_auth_service + ) azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) - assert azure_vm_service.deploy_params['customData'] + assert azure_vm_service.deploy_params["customData"] @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("start_host", True), ("stop_host", True), ("shutdown", True), @@ -113,22 +139,27 @@ def test_azure_vm_service_custom_data(azure_auth_service: 
AzureAuthService) -> N ("deallocate_host", True), ("restart_host", True), ("reboot", True), - ]) + ], +) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_vm_operation_status(mock_requests: MagicMock, - azure_vm_service: AzureVMService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status) -> None: +def test_vm_operation_status( + mock_requests: MagicMock, + azure_vm_service: AzureVMService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status, +) -> None: """ Test VM operation status. """ @@ -145,12 +176,16 @@ def test_vm_operation_status(mock_requests: MagicMock, @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("provision_host", True), - ]) -def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, - operation_name: str, - accepts_params: bool) -> None: + ], +) +def test_vm_operation_invalid( + azure_vm_service_remote_exec_only: AzureVMService, + operation_name: str, + accepts_params: bool, +) -> None: """ Test VM operation status for an incomplete service config. """ @@ -161,8 +196,9 @@ def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, @patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep") @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_ready( + mock_session: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService +) -> None: """ Test waiting for the completion of the remote VM operation. """ @@ -183,14 +219,15 @@ def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, status, _ = azure_vm_service.wait_host_operation(params) - assert (async_url, ) == mock_session.return_value.get.call_args[0] - assert (retry_after, ) == mock_sleep.call_args[0] + assert (async_url,) == mock_session.return_value.get.call_args[0] + assert (retry_after,) == mock_sleep.call_args[0] assert status.is_succeeded() @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_timeout(mock_session: MagicMock, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_timeout( + mock_session: MagicMock, azure_vm_service: AzureVMService +) -> None: """ Test the time out of the remote VM operation. 
""" @@ -198,7 +235,7 @@ def test_wait_vm_operation_timeout(mock_session: MagicMock, params = { "asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", - "pollInterval": 1 + "pollInterval": 1, } mock_status_response = MagicMock(status_code=200) @@ -212,16 +249,20 @@ def test_wait_vm_operation_timeout(mock_session: MagicMock, @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_vm_operation_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService, +) -> None: """ Test the retries of the remote VM operation. """ @@ -229,8 +270,12 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock, # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"status": "InProgress"}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), make_httplib_json_response(200, {"status": "InProgress"}), make_httplib_json_response(200, {"status": "Succeeded"}), ] @@ -241,20 +286,27 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock, "requestTotalRetries": total_retries, "asyncResultsUrl": "https://DUMMY_ASYNC_URL", "vmName": "test-vm", - }) + } + ) assert status == operation_status @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService, - http_status_code: int, operation_status: Status) -> None: +def test_remote_exec_status( + mock_requests: MagicMock, + azure_vm_service_remote_exec_only: AzureVMService, + http_status_code: int, + operation_status: Status, +) -> None: """ Test waiting for completion of the remote execution on Azure. 
""" @@ -262,19 +314,24 @@ def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_ex mock_response = MagicMock() mock_response.status_code = http_status_code - mock_response.json = MagicMock(return_value={ - "fake response": "body as json to dict", - }) + mock_response.json = MagicMock( + return_value={ + "fake response": "body as json to dict", + } + ) mock_requests.post.return_value = mock_response - status, _ = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={}) + status, _ = azure_vm_service_remote_exec_only.remote_exec( + script, config={"vmName": "test-vm"}, env_params={} + ) assert status == operation_status @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_headers_output(mock_requests: MagicMock, - azure_vm_service_remote_exec_only: AzureVMService) -> None: +def test_remote_exec_headers_output( + mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService +) -> None: """ Check if HTTP headers from the remote execution on Azure are correct. """ @@ -284,18 +341,22 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, mock_response = MagicMock() mock_response.status_code = 202 - mock_response.headers = { - "Azure-AsyncOperation": async_url_value - } - mock_response.json = MagicMock(return_value={ - "fake response": "body as json to dict", - }) + mock_response.headers = {"Azure-AsyncOperation": async_url_value} + mock_response.json = MagicMock( + return_value={ + "fake response": "body as json to dict", + } + ) mock_requests.post.return_value = mock_response - _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={ - "param_1": 123, - "param_2": "abc", - }) + _, cmd_output = azure_vm_service_remote_exec_only.remote_exec( + script, + config={"vmName": "test-vm"}, + env_params={ + "param_1": 123, + "param_2": "abc", + }, + ) assert async_url_key in cmd_output assert cmd_output[async_url_key] == async_url_value @@ -305,13 +366,14 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, "script": script, "parameters": [ {"name": "param_1", "value": 123}, - {"name": "param_2", "value": "abc"} - ] + {"name": "param_2", "value": "abc"}, + ], } @pytest.mark.parametrize( - ("operation_status", "wait_output", "results_output"), [ + ("operation_status", "wait_output", "results_output"), + [ ( Status.SUCCEEDED, { @@ -323,13 +385,18 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, } } }, - {"stdout": "DUMMY_STDOUT_STDERR"} + {"stdout": "DUMMY_STDOUT_STDERR"}, ), (Status.PENDING, {}, {}), (Status.FAILED, {}, {}), - ]) -def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status, - wait_output: dict, results_output: dict) -> None: + ], +) +def test_get_remote_exec_results( + azure_vm_service_remote_exec_only: AzureVMService, + operation_status: Status, + wait_output: dict, + results_output: dict, +) -> None: """ Test getting the results of the remote execution on Azure. 
""" @@ -338,9 +405,15 @@ def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMServi mock_wait_host_operation = MagicMock() mock_wait_host_operation.return_value = (operation_status, wait_output) # azure_vm_service.wait_host_operation = mock_wait_host_operation - setattr(azure_vm_service_remote_exec_only, "wait_host_operation", mock_wait_host_operation) - - status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params) + setattr( + azure_vm_service_remote_exec_only, + "wait_host_operation", + mock_wait_host_operation, + ) + + status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results( + params + ) assert status == operation_status assert mock_wait_host_operation.call_args[0][0] == params diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 2794bb01cf..1e997fc795 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -30,12 +30,16 @@ def config_persistence_service() -> ConfigPersistenceService: @pytest.fixture -def azure_auth_service(config_persistence_service: ConfigPersistenceService, - monkeypatch: pytest.MonkeyPatch) -> AzureAuthService: +def azure_auth_service( + config_persistence_service: ConfigPersistenceService, + monkeypatch: pytest.MonkeyPatch, +) -> AzureAuthService: """ Creates a dummy AzureAuthService for tests that require it. """ - auth = AzureAuthService(config={}, global_config={}, parent=config_persistence_service) + auth = AzureAuthService( + config={}, global_config={}, parent=config_persistence_service + ) monkeypatch.setattr(auth, "get_access_token", lambda: "TEST_TOKEN") return auth @@ -45,19 +49,23 @@ def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkS """ Creates a dummy Azure VM service for tests that require it. """ - return AzureNetworkService(config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", + return AzureNetworkService( + config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", + }, + "pollInterval": 1, + "pollTimeout": 2, }, - "pollInterval": 1, - "pollTimeout": 2 - }, global_config={ - "deploymentName": "TEST_DEPLOYMENT-VNET", - "vnetName": "test-vnet", # Should come from the upper-level config - }, parent=azure_auth_service) + global_config={ + "deploymentName": "TEST_DEPLOYMENT-VNET", + "vnetName": "test-vnet", # Should come from the upper-level config + }, + parent=azure_auth_service, + ) @pytest.fixture @@ -65,44 +73,60 @@ def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService: """ Creates a dummy Azure VM service for tests that require it. 
""" - return AzureVMService(config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", + return AzureVMService( + config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", + }, + "pollInterval": 1, + "pollTimeout": 2, + }, + global_config={ + "deploymentName": "TEST_DEPLOYMENT-VM", + "vmName": "test-vm", # Should come from the upper-level config }, - "pollInterval": 1, - "pollTimeout": 2 - }, global_config={ - "deploymentName": "TEST_DEPLOYMENT-VM", - "vmName": "test-vm", # Should come from the upper-level config - }, parent=azure_auth_service) + parent=azure_auth_service, + ) @pytest.fixture -def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> AzureVMService: +def azure_vm_service_remote_exec_only( + azure_auth_service: AzureAuthService, +) -> AzureVMService: """ Creates a dummy Azure VM service with no deployment template. """ - return AzureVMService(config={ - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "pollInterval": 1, - "pollTimeout": 2, - }, global_config={ - "vmName": "test-vm", # Should come from the upper-level config - }, parent=azure_auth_service) + return AzureVMService( + config={ + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "pollInterval": 1, + "pollTimeout": 2, + }, + global_config={ + "vmName": "test-vm", # Should come from the upper-level config + }, + parent=azure_auth_service, + ) @pytest.fixture -def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService: +def azure_fileshare( + config_persistence_service: ConfigPersistenceService, +) -> AzureFileShareService: """ Creates a dummy AzureFileShareService for tests that require it. """ with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"): - return AzureFileShareService(config={ - "storageAccountName": "TEST_ACCOUNT_NAME", - "storageFileShareName": "TEST_FS_NAME", - "storageAccountKey": "TEST_ACCOUNT_KEY" - }, global_config={}, parent=config_persistence_service) + return AzureFileShareService( + config={ + "storageAccountName": "TEST_ACCOUNT_NAME", + "storageFileShareName": "TEST_FS_NAME", + "storageAccountKey": "TEST_ACCOUNT_KEY", + }, + global_config={}, + parent=config_persistence_service, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index b9474f0709..fb1c4ee39b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -20,16 +20,24 @@ class MockAuthService(Service, SupportsAuth): A collection Service functions for mocking authentication ops. 
""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.get_access_token, - self.get_auth_headers, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.get_access_token, + self.get_auth_headers, + ], + ), ) def get_access_token(self) -> str: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index 1a026966a8..79f8c608c2 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -21,21 +21,30 @@ class MockFileShareService(FileShareService, SupportsFileShareOps): A collection Service functions for mocking file share ops. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) self._upload: List[Tuple[str, str]] = [] self._download: List[Tuple[str, str]] = [] - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: self._upload.append((local_path, remote_path)) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: self._download.append((remote_path, local_path)) def get_upload(self) -> List[Tuple[str, str]]: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py index e6169d9f93..6bf9fc8d05 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py @@ -20,10 +20,13 @@ class MockNetworkService(Service, SupportsNetworkProvisioning): Mock Network service for testing. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock network services proxy. 
@@ -38,13 +41,19 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - name: mock_operation for name in ( - # SupportsNetworkProvisioning: - "provision_network", - "deprovision_network", - "wait_network_deployment", - ) - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + name: mock_operation + for name in ( + # SupportsNetworkProvisioning: + "provision_network", + "deprovision_network", + "wait_network_deployment", + ) + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py index ee99251c64..38d759f53c 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py @@ -18,10 +18,13 @@ class MockRemoteExecService(Service, SupportsRemoteExec): Mock remote script execution service. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock remote exec service. @@ -36,9 +39,14 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - "remote_exec": mock_operation, - "get_remote_exec_results": mock_operation, - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + "remote_exec": mock_operation, + "get_remote_exec_results": mock_operation, + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index a44edaf080..3ae13cf6a6 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -20,10 +20,13 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, Supports Mock VM service for testing. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock VM services proxy. @@ -38,23 +41,29 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - name: mock_operation for name in ( - # SupportsHostProvisioning: - "wait_host_deployment", - "provision_host", - "deprovision_host", - "deallocate_host", - # SupportsHostOps: - "start_host", - "stop_host", - "restart_host", - "wait_host_operation", - # SupportsOsOps: - "shutdown", - "reboot", - "wait_os_operation", - ) - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + name: mock_operation + for name in ( + # SupportsHostProvisioning: + "wait_host_deployment", + "provision_host", + "deprovision_host", + "deallocate_host", + # SupportsHostOps: + "start_host", + "stop_host", + "restart_host", + "wait_host_operation", + # SupportsOsOps: + "shutdown", + "reboot", + "wait_os_operation", + ) + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index e0060d8047..c893adfd4a 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -17,9 +17,9 @@ # The SSH test server port and name. # See Also: docker-compose.yml SSH_TEST_SERVER_PORT = 2254 -SSH_TEST_SERVER_NAME = 'ssh-server' -ALT_TEST_SERVER_NAME = 'alt-server' -REBOOT_TEST_SERVER_NAME = 'reboot-server' +SSH_TEST_SERVER_NAME = "ssh-server" +ALT_TEST_SERVER_NAME = "alt-server" +REBOOT_TEST_SERVER_NAME = "reboot-server" @dataclass @@ -42,8 +42,12 @@ def get_port(self, uncached: bool = False) -> int: Note: this value can change when the service restarts so we can't rely on the DockerServices. """ if self._port is None or uncached: - port_cmd = run(f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", - shell=True, check=True, capture_output=True) + port_cmd = run( + f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", + shell=True, + check=True, + capture_output=True, + ) self._port = int(port_cmd.stdout.decode().strip().split(":")[1]) return self._port @@ -68,7 +72,9 @@ def to_connect_params(self, uncached: bool = False) -> dict: } -def wait_docker_service_socket(docker_services: DockerServices, hostname: str, port: int) -> None: +def wait_docker_service_socket( + docker_services: DockerServices, hostname: str, port: int +) -> None: """Wait until a docker service is ready.""" docker_services.wait_until_responsive( check=lambda: check_socket(hostname, port), diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 6f05fe953b..8b28856396 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -30,26 +30,28 @@ # pylint: disable=redefined-outer-name -HOST_DOCKER_NAME = 'host.docker.internal' +HOST_DOCKER_NAME = "host.docker.internal" @pytest.fixture(scope="session") def ssh_test_server_hostname() -> str: """Returns the local hostname to use to connect to the test ssh server.""" - if sys.platform != 'win32' and resolve_host_name(HOST_DOCKER_NAME): + if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): # On Linux, if we're running in a docker container, we can use the # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. 
return HOST_DOCKER_NAME # Docker (Desktop) for Windows (WSL2) uses a special networking magic # to refer to the host machine as `localhost` when exposing ports. # In all other cases, assume we're executing directly inside conda on the host. - return 'localhost' + return "localhost" @pytest.fixture(scope="session") -def ssh_test_server(ssh_test_server_hostname: str, - docker_compose_project_name: str, - locked_docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]: +def ssh_test_server( + ssh_test_server_hostname: str, + docker_compose_project_name: str, + locked_docker_services: DockerServices, +) -> Generator[SshTestServerInfo, None, None]: """ Fixture for getting the ssh test server services setup via docker-compose using pytest-docker. @@ -66,23 +68,37 @@ def ssh_test_server(ssh_test_server_hostname: str, compose_project_name=docker_compose_project_name, service_name=SSH_TEST_SERVER_NAME, hostname=ssh_test_server_hostname, - username='root', - id_rsa_path=id_rsa_file.name) - wait_docker_service_socket(locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port()) + username="root", + id_rsa_path=id_rsa_file.name, + ) + wait_docker_service_socket( + locked_docker_services, + ssh_test_server_info.hostname, + ssh_test_server_info.get_port(), + ) id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa" docker_cp_cmd = f"docker compose -p {docker_compose_project_name} cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}" - cmd = run(docker_cp_cmd.split(), check=True, cwd=os.path.dirname(__file__), capture_output=True, text=True) + cmd = run( + docker_cp_cmd.split(), + check=True, + cwd=os.path.dirname(__file__), + capture_output=True, + text=True, + ) if cmd.returncode != 0: - raise RuntimeError(f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " - + f"[return={cmd.returncode}]: {str(cmd.stderr)}") + raise RuntimeError( + f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " + + f"[return={cmd.returncode}]: {str(cmd.stderr)}" + ) os.chmod(id_rsa_file.name, 0o600) yield ssh_test_server_info # NamedTempFile deleted on context exit @pytest.fixture(scope="session") -def alt_test_server(ssh_test_server: SshTestServerInfo, - locked_docker_services: DockerServices) -> SshTestServerInfo: +def alt_test_server( + ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices +) -> SshTestServerInfo: """ Fixture for getting the second ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. @@ -95,14 +111,20 @@ def alt_test_server(ssh_test_server: SshTestServerInfo, service_name=ALT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path) - wait_docker_service_socket(locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port()) + id_rsa_path=ssh_test_server.id_rsa_path, + ) + wait_docker_service_socket( + locked_docker_services, + alt_test_server_info.hostname, + alt_test_server_info.get_port(), + ) return alt_test_server_info @pytest.fixture(scope="session") -def reboot_test_server(ssh_test_server: SshTestServerInfo, - locked_docker_services: DockerServices) -> SshTestServerInfo: +def reboot_test_server( + ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices +) -> SshTestServerInfo: """ Fixture for getting the third ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. 
@@ -115,8 +137,13 @@ def reboot_test_server(ssh_test_server: SshTestServerInfo, service_name=REBOOT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path) - wait_docker_service_socket(locked_docker_services, reboot_test_server_info.hostname, reboot_test_server_info.get_port()) + id_rsa_path=ssh_test_server.id_rsa_path, + ) + wait_docker_service_socket( + locked_docker_services, + reboot_test_server_info.hostname, + reboot_test_server_info.get_port(), + ) return reboot_test_server_info diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index f2bbbe4b8a..e3b9c85746 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -52,8 +52,9 @@ def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, @requires_docker -def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_single_file( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService single file download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -66,7 +67,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, lines = [line + "\n" for line in lines] # 1. Write a local file and upload it. - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines(lines) temp_file.flush() temp_file.close() @@ -78,7 +79,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, ) # 2. Download the remote file and compare the contents. - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.close() ssh_fileshare_service.download( params=config, @@ -86,14 +87,15 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, local_path=temp_file.name, ) # Download will replace the inode at that name, so we need to reopen the file. - with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: + with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == lines @requires_docker -def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_recursive( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService recursive download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -113,14 +115,17 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, "bar", ], } - files_lines = {path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()} + files_lines = { + path: [line + "\n" for line in lines] + for (path, lines) in files_lines.items() + } with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2: # Setup the directory structure. 
- for (file_path, lines) in files_lines.items(): + for file_path, lines in files_lines.items(): path = Path(tempdir1, file_path) path.parent.mkdir(parents=True, exist_ok=True) - with open(path, mode='w+t', encoding='utf-8') as temp_file: + with open(path, mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines(lines) temp_file.flush() assert os.path.getsize(path) > 0 @@ -143,19 +148,22 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, # Compare both. # Note: remote dir name is appended to target. - assert are_dir_trees_equal(tempdir1, path_join(tempdir2, basename(remote_file_path))) + assert are_dir_trees_equal( + tempdir1, path_join(tempdir2, basename(remote_file_path)) + ) @requires_docker -def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_download_file_dne( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService single file download that doesn't exist.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() canary_str = "canary" - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines([canary_str]) temp_file.flush() temp_file.close() @@ -166,20 +174,22 @@ def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, remote_path="/tmp/file-dne.txt", local_path=temp_file.name, ) - with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: + with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == [canary_str] @requires_docker -def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_upload_file_dne( + ssh_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + ssh_fileshare_service: SshFileShareService, +) -> None: """Test the SshFileShareService single file upload that doesn't exist.""" with ssh_host_service, ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() - path = '/tmp/upload-file-src-dne.txt' + path = "/tmp/upload-file-src-dne.txt" with pytest.raises(OSError): ssh_fileshare_service.upload( params=config, diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 4c8e5e0c66..6cea52a102 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -27,9 +27,11 @@ @requires_docker -def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, - alt_test_server: SshTestServerInfo, - ssh_host_service: SshHostService) -> None: +def test_ssh_service_remote_exec( + ssh_test_server: SshTestServerInfo, + alt_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, +) -> None: """ Test the SshHostService remote_exec. 
@@ -42,7 +44,11 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, connection_id = SshClient.id_from_params(ssh_test_server.to_connect_params()) assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None - connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(connection_id) + connection_client = ( + ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get( + connection_id + ) + ) assert connection_client is None (status, results_info) = ssh_host_service.remote_exec( @@ -57,7 +63,9 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, assert results["stdout"].strip() == SSH_TEST_SERVER_NAME # Check that the client caching is behaving as expected. - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ( + ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + ) assert connection is not None assert connection._username == ssh_test_server.username assert connection._host == ssh_test_server.hostname @@ -91,13 +99,15 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) - assert status.is_failed() # should retain exit code from "false" + assert status.is_failed() # should retain exit code from "false" stdout = str(results["stdout"]) assert stdout.splitlines() == [ "BAR=bar", "UNUSED=", ] - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ( + ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + ) assert connection._local_port == local_port # Close the connection (gracefully) @@ -114,7 +124,7 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, config=config, # Also test interacting with environment_variables. env_params={ - 'FOO': 'foo', + "FOO": "foo", }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) @@ -127,17 +137,21 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, "BAZ=", ] # Make sure it looks like we reconnected. - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ( + ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + ) assert connection._local_port != local_port # Make sure the cache is cleaned up on context exit. assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0 -def check_ssh_service_reboot(docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - graceful: bool) -> None: +def check_ssh_service_reboot( + docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + graceful: bool, +) -> None: """ Check the SshHostService reboot operation. """ @@ -146,12 +160,14 @@ def check_ssh_service_reboot(docker_services: DockerServices, # Also, it may cause issues with other parallel unit tests, so we run it as # a part of the same unit test for now. 
with ssh_host_service: - reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True) + reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config( + uncached=True + ) (status, results_info) = ssh_host_service.remote_exec( script=[ 'echo "sleeping..."', - 'sleep 30', - 'echo "should not reach this point"' + "sleep 30", + 'echo "should not reach this point"', ], config=reboot_test_srv_ssh_svc_conf, env_params={}, @@ -161,11 +177,14 @@ def check_ssh_service_reboot(docker_services: DockerServices, time.sleep(1) # Now try to restart the server. - (status, reboot_results_info) = ssh_host_service.reboot(params=reboot_test_srv_ssh_svc_conf, - force=not graceful) + (status, reboot_results_info) = ssh_host_service.reboot( + params=reboot_test_srv_ssh_svc_conf, force=not graceful + ) assert status.is_pending() - (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info) + (status, reboot_results_info) = ssh_host_service.wait_os_operation( + reboot_results_info + ) # NOTE: reboot/shutdown ops mostly return FAILED, even though the reboot succeeds. _LOG.debug("reboot status: %s: %s", status, reboot_results_info) @@ -183,19 +202,34 @@ def check_ssh_service_reboot(docker_services: DockerServices, time.sleep(1) # try to reconnect and see if the port changed try: - run_res = run("docker ps | grep mlos_bench-test- | grep reboot", shell=True, capture_output=True, check=False) + run_res = run( + "docker ps | grep mlos_bench-test- | grep reboot", + shell=True, + capture_output=True, + check=False, + ) print(run_res.stdout.decode()) print(run_res.stderr.decode()) - reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(uncached=True) - if reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"]: + reboot_test_srv_ssh_svc_conf_new = ( + reboot_test_server.to_ssh_service_config(uncached=True) + ) + if ( + reboot_test_srv_ssh_svc_conf_new["ssh_port"] + != reboot_test_srv_ssh_svc_conf["ssh_port"] + ): break except CalledProcessError as ex: _LOG.info("Failed to check port for reboot test server: %s", ex) - assert reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"] + assert ( + reboot_test_srv_ssh_svc_conf_new["ssh_port"] + != reboot_test_srv_ssh_svc_conf["ssh_port"] + ) - wait_docker_service_socket(docker_services, - reboot_test_server.hostname, - reboot_test_srv_ssh_svc_conf_new["ssh_port"]) + wait_docker_service_socket( + docker_services, + reboot_test_server.hostname, + reboot_test_srv_ssh_svc_conf_new["ssh_port"], + ) (status, results_info) = ssh_host_service.remote_exec( script=["hostname"], @@ -208,12 +242,18 @@ def check_ssh_service_reboot(docker_services: DockerServices, @requires_docker -def test_ssh_service_reboot(locked_docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService) -> None: +def test_ssh_service_reboot( + locked_docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, +) -> None: """ Test the SshHostService reboot operation. """ # Grouped together to avoid parallel runner interactions. 
- check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=True) - check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=False) + check_ssh_service_reboot( + locked_docker_services, reboot_test_server, ssh_host_service, graceful=True + ) + check_ssh_service_reboot( + locked_docker_services, reboot_test_server, ssh_host_service, graceful=False + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index 7bee929fea..b8e489b030 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -35,7 +35,9 @@ # We replaced pytest-lazy-fixture with pytest-lazy-fixtures: # https://github.com/TvoroG/pytest-lazy-fixture/issues/65 if version("pytest-lazy-fixture"): - raise UserWarning("pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it.") + raise UserWarning( + "pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it." + ) except PackageNotFoundError: # OK: pytest-lazy-fixture not installed pass @@ -43,12 +45,16 @@ @requires_docker @requires_ssh -@pytest.mark.parametrize(["ssh_test_server_info", "server_name"], [ - (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), - (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), -]) -def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, - server_name: str) -> None: +@pytest.mark.parametrize( + ["ssh_test_server_info", "server_name"], + [ + (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), + (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), + ], +) +def test_ssh_service_test_infra( + ssh_test_server_info: SshTestServerInfo, server_name: str +) -> None: """Check for the pytest-docker ssh test infra.""" assert ssh_test_server_info.service_name == server_name @@ -57,17 +63,18 @@ def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, local_port = ssh_test_server_info.get_port() assert check_socket(ip_addr, local_port) - ssh_cmd = "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " \ - + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " \ + ssh_cmd = ( + "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " + + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " + f"-p {local_port} {ssh_test_server_info.hostname} hostname" - cmd = run(ssh_cmd.split(), - capture_output=True, - text=True, - check=True) + ) + cmd = run(ssh_cmd.split(), capture_output=True, text=True, check=True) assert cmd.stdout.strip() == server_name -@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") +@pytest.mark.filterwarnings( + "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" +) def test_ssh_service_context_handler() -> None: """ Test the SSH service context manager handling. 
@@ -93,31 +100,43 @@ def test_ssh_service_context_handler() -> None: time.sleep(0.25) assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None - ssh_fileshare_service = SshFileShareService(config={}, global_config={}, parent=None) + ssh_fileshare_service = SshFileShareService( + config={}, global_config={}, parent=None + ) assert ssh_fileshare_service assert not ssh_fileshare_service._in_context with ssh_fileshare_service: assert ssh_fileshare_service._in_context assert ssh_host_service._in_context - assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread \ - is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread \ + assert ( + SshService._EVENT_LOOP_CONTEXT._event_loop_thread + is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread is ssh_fileshare_service._EVENT_LOOP_CONTEXT._event_loop_thread - assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ - is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ + ) + assert ( + SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is ssh_fileshare_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + ) assert not ssh_fileshare_service._in_context # And that instance should be unusable after we are outside the context. - with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result='foo')) - raise ValueError(f"Future should not have been available to wait on {future.result()}") + with pytest.raises( + AssertionError + ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = ssh_fileshare_service._run_coroutine( + asyncio.sleep(0.1, result="foo") + ) + raise ValueError( + f"Future should not have been available to wait on {future.result()}" + ) # The background thread should remain running since we have another context still open. assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None -if __name__ == '__main__': +if __name__ == "__main__": # For debugging in Windows which has issues with pytest detection in vscode. 
pytest.main(["-n1", "--dist=no", "-k", "test_ssh_service_background_thread"]) diff --git a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py index 463879634f..31daec07c3 100644 --- a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py +++ b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py @@ -40,7 +40,10 @@ def test_service_method_register_without_constructor() -> None: # somehow having it in a different scope makes a difference if isinstance(mixin_service, SupportsSomeMethod): assert mixin_service.some_method() == f"{some_base_service}: base.some_method" - assert mixin_service.some_other_method() == f"{some_base_service}: base.some_other_method" + assert ( + mixin_service.some_other_method() + == f"{some_base_service}: base.some_other_method" + ) # register the child service mixin_service.register(some_child_service.export()) @@ -48,6 +51,9 @@ def test_service_method_register_without_constructor() -> None: assert mixin_service._services == {some_child_service} # check that the inheritance works as expected assert mixin_service.some_method() == f"{some_child_service}: child.some_method" - assert mixin_service.some_other_method() == f"{some_child_service}: base.some_other_method" + assert ( + mixin_service.some_other_method() + == f"{some_child_service}: base.some_other_method" + ) else: assert False diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index 2c16df65c4..7b859e79ba 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -18,8 +18,12 @@ exp_no_tunables_storage = sql_storage_fixtures.exp_no_tunables_storage mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage exp_storage_with_trials = sql_storage_fixtures.exp_storage_with_trials -exp_no_tunables_storage_with_trials = sql_storage_fixtures.exp_no_tunables_storage_with_trials -mixed_numerics_exp_storage_with_trials = sql_storage_fixtures.mixed_numerics_exp_storage_with_trials +exp_no_tunables_storage_with_trials = ( + sql_storage_fixtures.exp_no_tunables_storage_with_trials +) +mixed_numerics_exp_storage_with_trials = ( + sql_storage_fixtures.mixed_numerics_exp_storage_with_trials +) exp_data = sql_storage_fixtures.exp_data exp_no_tunables_data = sql_storage_fixtures.exp_no_tunables_data mixed_numerics_exp_data = sql_storage_fixtures.mixed_numerics_exp_data diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index 8159043be1..852155a8c6 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -22,23 +22,32 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) assert exp.objectives == exp_storage.opt_targets -def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None: +def test_exp_data_root_env_config( + exp_storage: Storage.Experiment, exp_data: ExperimentData +) -> None: """Tests the root_env_config property of ExperimentData""" # pylint: disable=protected-access - assert exp_data.root_env_config == (exp_storage._root_env_config, exp_storage._git_repo, exp_storage._git_commit) + assert exp_data.root_env_config == ( + exp_storage._root_env_config, + exp_storage._git_repo, + exp_storage._git_commit, + ) -def 
test_exp_trial_data_objectives(storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_data_objectives( + storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups +) -> None: """ Start a new trial and check the storage for the trial data. """ - trial_opt_new = exp_storage.new_trial(tunable_groups, config={ - "opt_target": "some-other-target", - "opt_direction": "max", - }) + trial_opt_new = exp_storage.new_trial( + tunable_groups, + config={ + "opt_target": "some-other-target", + "opt_direction": "max", + }, + ) assert trial_opt_new.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_new.trial_id, @@ -46,10 +55,13 @@ def test_exp_trial_data_objectives(storage: Storage, "opt_direction": "max", } - trial_opt_old = exp_storage.new_trial(tunable_groups, config={ - "opt_target": "back-compat", - # "opt_direction": "max", # missing - }) + trial_opt_old = exp_storage.new_trial( + tunable_groups, + config={ + "opt_target": "back-compat", + # "opt_direction": "max", # missing + }, + ) assert trial_opt_old.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_old.trial_id, @@ -66,7 +78,9 @@ def test_exp_trial_data_objectives(storage: Storage, } -def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: +def test_exp_data_results_df( + exp_data: ExperimentData, tunable_groups: TunableGroups +) -> None: """Tests the results_df property of ExperimentData""" results_df = exp_data.results_df expected_trials_count = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT @@ -74,12 +88,20 @@ def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGr assert len(results_df["tunable_config_id"].unique()) == CONFIG_COUNT assert len(results_df["trial_id"].unique()) == expected_trials_count obj_target = next(iter(exp_data.objectives)) - assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count + assert ( + len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) + == expected_trials_count + ) (tunable, _covariant_group) = next(iter(tunable_groups)) - assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_trials_count + assert ( + len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) + == expected_trials_count + ) -def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None: +def test_exp_data_tunable_config_trial_group_id_in_results_df( + exp_data: ExperimentData, +) -> None: """ Tests the tunable_config_trial_group_id property of ExperimentData.results_df @@ -114,15 +136,21 @@ def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None: This tests bulk loading of the tunable_config_trial_groups. """ # Should be keyed by config_id. - assert list(exp_data.tunable_config_trial_groups.keys()) == list(range(1, CONFIG_COUNT + 1)) + assert list(exp_data.tunable_config_trial_groups.keys()) == list( + range(1, CONFIG_COUNT + 1) + ) # Which should match the objects. - assert [config_trial_group.tunable_config_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [ + config_trial_group.tunable_config_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT + 1)) # And the tunable_config_trial_group_id should also match the minimum trial_id. 
- assert [config_trial_group.tunable_config_trial_group_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) + assert [ + config_trial_group.tunable_config_trial_group_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list( + range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT) + ) def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: @@ -130,9 +158,9 @@ def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_configs.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. - assert [config.tunable_config_id - for config in exp_data.tunable_configs.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [ + config.tunable_config_id for config in exp_data.tunable_configs.values() + ] == list(range(1, CONFIG_COUNT + 1)) def test_exp_data_default_config_id(exp_data: ExperimentData) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index d0a5edc694..91920190de 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -37,9 +37,11 @@ def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], +) -> None: """ Start a trial and check that it is pending. """ @@ -50,14 +52,16 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_many(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_many( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], +) -> None: """ Start THREE trials and check that both are pending. """ - config1 = tunable_groups.copy().assign({'idle': 'mwait'}) - config2 = tunable_groups.copy().assign({'idle': 'noidle'}) + config1 = tunable_groups.copy().assign({"idle": "mwait"}) + config2 = tunable_groups.copy().assign({"idle": "noidle"}) trial_ids = { exp_storage.new_trial(config1).trial_id, exp_storage.new_trial(config2).trial_id, @@ -72,9 +76,11 @@ def test_exp_trial_pending_many(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_fail( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], +) -> None: """ Start a trial, fail it, and and check that it is NOT pending. 
""" @@ -85,9 +91,11 @@ def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_success(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_success( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], +) -> None: """ Start a trial, finish it successfully, and and check that it is NOT pending. """ @@ -98,31 +106,39 @@ def test_exp_trial_success(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_categ(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_update_categ( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], +) -> None: """ Update the trial with multiple metrics, some of which are categorical. """ trial = exp_storage.new_trial(tunable_groups) - trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"}) + trial.update( + Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"} + ) assert exp_storage.load() == ( [trial.trial_id], - [{ - 'idle': 'halt', - 'kernel_sched_latency_ns': '2000000', - 'kernel_sched_migration_cost_ns': '-1', - 'vmSize': 'Standard_B4ms' - }], + [ + { + "idle": "halt", + "kernel_sched_latency_ns": "2000000", + "kernel_sched_migration_cost_ns": "-1", + "vmSize": "Standard_B4ms", + } + ], [{"score": "99.9", "benchmark": "test"}], - [Status.SUCCEEDED] + [Status.SUCCEEDED], ) @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_twice(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_update_twice( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], +) -> None: """ Update the trial status twice and receive an error. """ @@ -133,9 +149,11 @@ def test_exp_trial_update_twice(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_3(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_3( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], +) -> None: """ Start THREE trials, let one succeed, another one fail and keep one not updated. Check that one is still pending another one can be loaded into the optimizer. diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 7e346a5ccc..5d56a3e195 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -36,7 +36,7 @@ def storage() -> SqlStorage: "drivername": "sqlite", "database": ":memory:", # "database": "mlos_bench.pytest.db", - } + }, ) @@ -106,7 +106,9 @@ def mixed_numerics_exp_storage( assert not exp._in_context -def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> SqlStorage.Experiment: +def _dummy_run_exp( + exp: SqlStorage.Experiment, tunable_name: Optional[str] +) -> SqlStorage.Experiment: """ Generates data by doing a simulated run of the given experiment. 
""" @@ -119,47 +121,68 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S (tunable_min, tunable_max) = tunable.range tunable_range = tunable_max - tunable_min rand_seed(SEED) - opt = MockOptimizer(tunables=exp.tunables, config={ - "seed": SEED, - # This should be the default, so we leave it omitted for now to test the default. - # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) - # "start_with_defaults": True, - }) + opt = MockOptimizer( + tunables=exp.tunables, + config={ + "seed": SEED, + # This should be the default, so we leave it omitted for now to test the default. + # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) + # "start_with_defaults": True, + }, + ) assert opt.start_with_defaults for config_i in range(CONFIG_COUNT): tunables = opt.suggest() for repeat_j in range(CONFIG_TRIAL_REPEAT_COUNT): - trial = exp.new_trial(tunables=tunables.copy(), config={ - "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(exp.opt_targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - } - }) + trial = exp.new_trial( + tunables=tunables.copy(), + config={ + "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(exp.opt_targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + }, + }, + ) if exp.tunables: assert trial.tunable_config_id == config_i + 1 else: assert trial.tunable_config_id == 1 if tunable_name: - tunable_value = float(tunables.get_tunable(tunable_name)[0].numerical_value) - tunable_value_norm = base_score * (tunable_value - tunable_min) / tunable_range + tunable_value = float( + tunables.get_tunable(tunable_name)[0].numerical_value + ) + tunable_value_norm = ( + base_score * (tunable_value - tunable_min) / tunable_range + ) else: tunable_value_norm = 0 timestamp = datetime.now(UTC) - trial.update_telemetry(status=Status.RUNNING, timestamp=timestamp, metrics=[ - (timestamp, "some-metric", tunable_value_norm + random() / 100), - ]) - trial.update(Status.SUCCEEDED, timestamp, metrics={ - # Give some variance on the score. - # And some influence from the tunable value. - "score": tunable_value_norm + random() / 100 - }) + trial.update_telemetry( + status=Status.RUNNING, + timestamp=timestamp, + metrics=[ + (timestamp, "some-metric", tunable_value_norm + random() / 100), + ], + ) + trial.update( + Status.SUCCEEDED, + timestamp, + metrics={ + # Give some variance on the score. + # And some influence from the tunable value. + "score": tunable_value_norm + + random() / 100 + }, + ) return exp @pytest.fixture -def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: +def exp_storage_with_trials( + exp_storage: SqlStorage.Experiment, +) -> SqlStorage.Experiment: """ Test fixture for Experiment using in-memory SQLite3 storage. """ @@ -167,7 +190,9 @@ def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Ex @pytest.fixture -def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: +def exp_no_tunables_storage_with_trials( + exp_no_tunables_storage: SqlStorage.Experiment, +) -> SqlStorage.Experiment: """ Test fixture for Experiment using in-memory SQLite3 storage. 
""" @@ -176,7 +201,9 @@ def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Expe @pytest.fixture -def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: +def mixed_numerics_exp_storage_with_trials( + mixed_numerics_exp_storage: SqlStorage.Experiment, +) -> SqlStorage.Experiment: """ Test fixture for Experiment using in-memory SQLite3 storage. """ @@ -185,7 +212,9 @@ def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorag @pytest.fixture -def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def exp_data( + storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """ Test fixture for ExperimentData. """ @@ -193,7 +222,9 @@ def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment @pytest.fixture -def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def exp_no_tunables_data( + storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """ Test fixture for ExperimentData with no tunable configs. """ @@ -201,7 +232,9 @@ def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trial @pytest.fixture -def mixed_numerics_exp_data(storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def mixed_numerics_exp_data( + storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """ Test fixture for ExperimentData with mixed numerical tunable types. """ diff --git a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py index ba965ed3c6..e8c4d38a9a 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py @@ -13,8 +13,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_exp_trial_pending(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_pending( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups +) -> None: """ Schedule a trial and check that it is pending and has the right configuration. """ @@ -31,13 +32,14 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, } -def test_exp_trial_configs(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_configs( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups +) -> None: """ Start multiple trials with two different configs and check that we store only two config objects in the DB. 
""" - config1 = tunable_groups.copy().assign({'idle': 'mwait'}) + config1 = tunable_groups.copy().assign({"idle": "mwait"}) trials1 = [ exp_storage.new_trial(config1), exp_storage.new_trial(config1), @@ -46,7 +48,7 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment, assert trials1[0].tunable_config_id == trials1[1].tunable_config_id assert trials1[0].tunable_config_id == trials1[2].tunable_config_id - config2 = tunable_groups.copy().assign({'idle': 'halt'}) + config2 = tunable_groups.copy().assign({"idle": "halt"}) trials2 = [ exp_storage.new_trial(config2), exp_storage.new_trial(config2), @@ -63,7 +65,10 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment, ] assert len(pending_ids) == 6 assert len(set(pending_ids)) == 2 - assert set(pending_ids) == {trials1[0].tunable_config_id, trials2[0].tunable_config_id} + assert set(pending_ids) == { + trials1[0].tunable_config_id, + trials2[0].tunable_config_id, + } def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index 04f4f18ae3..c56efa0031 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -22,8 +22,9 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]: return set(t.trial_id for t in trials) -def test_schedule_trial(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_schedule_trial( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups +) -> None: """ Schedule several trials for future execution and retrieve them later at certain timestamps. """ @@ -39,13 +40,16 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Schedule 1 hour in the future: trial_1h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr, config) # Schedule 2 hours in the future: - trial_2h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr * 2, config) + trial_2h = exp_storage.new_trial( + tunable_groups, timestamp + timedelta_1hr * 2, config + ) # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) + exp_storage.pending_trials(timestamp + timedelta_1min, running=False) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -53,7 +57,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run within the next 1 hour: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) + exp_storage.pending_trials(timestamp + timedelta_1hr, running=False) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -62,7 +67,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -84,7 +90,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) + exp_storage.pending_trials(timestamp + 
timedelta_1hr * 3, running=False) + ) assert pending_ids == { trial_1h.trial_id, trial_2h.trial_id, @@ -92,7 +99,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run OR running within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -101,11 +109,15 @@ def test_schedule_trial(exp_storage: Storage.Experiment, } # Mark some trials completed after 2 minutes: - trial_now1.update(Status.SUCCEEDED, timestamp + timedelta_1min * 2, metrics={"score": 1.0}) + trial_now1.update( + Status.SUCCEEDED, timestamp + timedelta_1min * 2, metrics={"score": 1.0} + ) trial_now2.update(Status.FAILED, timestamp + timedelta_1min * 2) # Another one completes after 2 hours: - trial_1h.update(Status.SUCCEEDED, timestamp + timedelta_1hr * 2, metrics={"score": 1.0}) + trial_1h.update( + Status.SUCCEEDED, timestamp + timedelta_1hr * 2, metrics={"score": 1.0} + ) # Check that three trials have completed so far: (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load() @@ -114,7 +126,9 @@ def test_schedule_trial(exp_storage: Storage.Experiment, assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED] # Get only trials completed after trial_now2: - (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(last_trial_id=trial_now2.trial_id) + (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load( + last_trial_id=trial_now2.trial_id + ) assert trial_ids == [trial_1h.trial_id] assert len(trial_configs) == len(trial_scores) == 1 assert trial_status == [Status.SUCCEEDED] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index 855c6cd861..e1f033fae9 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -20,7 +20,9 @@ # pylint: disable=redefined-outer-name -def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]: +def zoned_telemetry_data( + zone_info: Optional[tzinfo], +) -> List[Tuple[datetime, str, Any]]: """ Mock telemetry data for the trial. @@ -31,18 +33,21 @@ def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, st """ timestamp1 = datetime.now(zone_info) timestamp2 = timestamp1 + timedelta(seconds=1) - return sorted([ - (timestamp1, "cpu_load", 10.1), - (timestamp1, "memory", 20), - (timestamp1, "setup", "prod"), - (timestamp2, "cpu_load", 30.1), - (timestamp2, "memory", 40), - (timestamp2, "setup", "prod"), - ]) + return sorted( + [ + (timestamp1, "cpu_load", 10.1), + (timestamp1, "memory", 20), + (timestamp1, "setup", "prod"), + (timestamp2, "cpu_load", 30.1), + (timestamp2, "memory", 40), + (timestamp2, "setup", "prod"), + ] + ) -def _telemetry_str(data: List[Tuple[datetime, str, Any]] - ) -> List[Tuple[datetime, str, Optional[str]]]: +def _telemetry_str( + data: List[Tuple[datetime, str, Any]] +) -> List[Tuple[datetime, str, Optional[str]]]: """ Convert telemetry values to strings. 
""" @@ -51,10 +56,12 @@ def _telemetry_str(data: List[Tuple[datetime, str, Any]] @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry(storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo]) -> None: +def test_update_telemetry( + storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo], +) -> None: """ Make sure update_telemetry() and load_telemetry() methods work. """ @@ -62,7 +69,9 @@ def test_update_telemetry(storage: Storage, trial = exp_storage.new_trial(tunable_groups) assert exp_storage.load_telemetry(trial.trial_id) == [] - trial.update_telemetry(Status.RUNNING, datetime.now(origin_zone_info), telemetry_data) + trial.update_telemetry( + Status.RUNNING, datetime.now(origin_zone_info), telemetry_data + ) assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(telemetry_data) # Also check that the TrialData telemetry looks right. @@ -73,9 +82,11 @@ def test_update_telemetry(storage: Storage, @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry_twice(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo]) -> None: +def test_update_telemetry_twice( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo], +) -> None: """ Make sure update_telemetry() call is idempotent. """ diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index 3b57222822..a3333acd2b 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -10,8 +10,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_trial_data_tunable_config_data(exp_data: ExperimentData, - tunable_groups: TunableGroups) -> None: +def test_trial_data_tunable_config_data( + exp_data: ExperimentData, tunable_groups: TunableGroups +) -> None: """ Check expected return values for TunableConfigData. """ @@ -29,16 +30,18 @@ def test_trial_metadata(exp_data: ExperimentData) -> None: """ Check expected return values for TunableConfigData metadata. """ - assert exp_data.objectives == {'score': 'min'} - for (trial_id, trial) in exp_data.trials.items(): + assert exp_data.objectives == {"score": "min"} + for trial_id, trial in exp_data.trials.items(): assert trial.metadata_dict == { - 'opt_target_0': 'score', - 'opt_direction_0': 'min', - 'trial_number': trial_id, + "opt_target_0": "score", + "opt_direction_0": "min", + "trial_number": trial_id, } -def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData) -> None: +def test_trial_data_no_tunables_config_data( + exp_no_tunables_data: ExperimentData, +) -> None: """ Check expected return values for TunableConfigData. """ @@ -48,13 +51,14 @@ def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData def test_mixed_numerics_exp_trial_data( - mixed_numerics_exp_data: ExperimentData, - mixed_numerics_tunable_groups: TunableGroups) -> None: + mixed_numerics_exp_data: ExperimentData, + mixed_numerics_tunable_groups: TunableGroups, +) -> None: """ Tests that data type conversions are retained when loading experiment data with mixed numeric tunable types. 
""" trial = next(iter(mixed_numerics_exp_data.trials.values())) config = trial.tunable_config.config_dict - for (tunable, _group) in mixed_numerics_tunable_groups: + for tunable, _group in mixed_numerics_tunable_groups: assert isinstance(config[tunable.name], tunable.dtype) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index d08b26e92d..987b1a75b2 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -16,10 +16,19 @@ def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None: trial_id = 1 trial = exp_data.trials[trial_id] tunable_config_trial_group = trial.tunable_config_trial_group - assert tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id + assert ( + tunable_config_trial_group.experiment_id + == exp_data.experiment_id + == trial.experiment_id + ) assert tunable_config_trial_group.tunable_config_id == trial.tunable_config_id assert tunable_config_trial_group.tunable_config == trial.tunable_config - assert tunable_config_trial_group == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group + assert ( + tunable_config_trial_group + == next( + iter(tunable_config_trial_group.trials.values()) + ).tunable_config_trial_group + ) def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None: @@ -49,7 +58,9 @@ def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) # And so on ... -def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: +def test_tunable_config_trial_group_results_df( + exp_data: ExperimentData, tunable_groups: TunableGroups +) -> None: """Tests the results_df property of the TunableConfigTrialGroup.""" tunable_config_id = 2 expected_group_id = 4 @@ -58,15 +69,38 @@ def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable # We shouldn't have the results for the other configs, just this one. 
expected_count = CONFIG_TRIAL_REPEAT_COUNT assert len(results_df) == expected_count - assert len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count + assert ( + len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) + == expected_count + ) assert len(results_df[(results_df["tunable_config_id"] != tunable_config_id)]) == 0 - assert len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) == expected_count - assert len(results_df[(results_df["tunable_config_trial_group_id"] != expected_group_id)]) == 0 + assert ( + len( + results_df[ + (results_df["tunable_config_trial_group_id"] == expected_group_id) + ] + ) + == expected_count + ) + assert ( + len( + results_df[ + (results_df["tunable_config_trial_group_id"] != expected_group_id) + ] + ) + == 0 + ) assert len(results_df["trial_id"].unique()) == expected_count obj_target = next(iter(exp_data.objectives)) - assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_count + assert ( + len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) + == expected_count + ) (tunable, _covariant_group) = next(iter(tunable_groups)) - assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_count + assert ( + len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) + == expected_count + ) def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None: @@ -76,8 +110,16 @@ def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None: tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id] trials = tunable_config_trial_group.trials assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT - assert all(trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id - for trial in trials.values()) - assert all(trial.tunable_config_id == tunable_config_id - for trial in tunable_config_trial_group.trials.values()) - assert exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] + assert all( + trial.tunable_config_trial_group.tunable_config_trial_group_id + == expected_group_id + for trial in trials.values() + ) + assert all( + trial.tunable_config_id == tunable_config_id + for trial in tunable_config_trial_group.trials.values() + ) + assert ( + exp_data.trials[expected_group_id] + == tunable_config_trial_group.trials[expected_group_id] + ) diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py index fa947610da..2aba200955 100644 --- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py +++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py @@ -24,7 +24,9 @@ ] -@pytest.mark.skipif(sys.platform == 'win32', reason="TZ environment variable is a UNIXism") +@pytest.mark.skipif( + sys.platform == "win32", reason="TZ environment variable is a UNIXism" +) @pytest.mark.parametrize(("tz_name"), ZONE_NAMES) @pytest.mark.parametrize(("test_file"), TZ_TEST_FILES) def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: @@ -45,4 +47,6 @@ def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: if cmd.returncode != 0: print(cmd.stdout.decode()) print(cmd.stderr.decode()) - raise AssertionError(f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'") + raise AssertionError( + f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'" + ) diff --git 
a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py index 822547b1da..8329b51bd0 100644 --- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py +++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py @@ -119,24 +119,26 @@ def mixed_numerics_tunable_groups() -> TunableGroups: tunable_groups : TunableGroups A new TunableGroups object for testing. """ - tunables = TunableGroups({ - "mix-numerics": { - "cost": 1, - "params": { - "int": { - "description": "An integer", - "type": "int", - "default": 0, - "range": [0, 100], + tunables = TunableGroups( + { + "mix-numerics": { + "cost": 1, + "params": { + "int": { + "description": "An integer", + "type": "int", + "default": 0, + "range": [0, 100], + }, + "float": { + "description": "A float", + "type": "float", + "default": 0, + "range": [0, 1], + }, }, - "float": { - "description": "A float", - "type": "float", - "default": 0, - "range": [0, 1], - }, - } - }, - }) + }, + } + ) tunables.reset() return tunables diff --git a/mlos_bench/mlos_bench/tests/tunables/conftest.py b/mlos_bench/mlos_bench/tests/tunables/conftest.py index 95de20d9b8..878471b59e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/conftest.py +++ b/mlos_bench/mlos_bench/tests/tunables/conftest.py @@ -25,12 +25,15 @@ def tunable_categorical() -> Tunable: tunable : Tunable An instance of a categorical Tunable. """ - return Tunable("vmSize", { - "description": "Azure VM size", - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] - }) + return Tunable( + "vmSize", + { + "description": "Azure VM size", + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + }, + ) @pytest.fixture @@ -43,13 +46,16 @@ def tunable_int() -> Tunable: tunable : Tunable An instance of an integer Tunable. """ - return Tunable("kernel_sched_migration_cost_ns", { - "description": "Cost of migrating the thread to another core", - "type": "int", - "default": 40000, - "range": [0, 500000], - "special": [-1] # Special value outside of the range - }) + return Tunable( + "kernel_sched_migration_cost_ns", + { + "description": "Cost of migrating the thread to another core", + "type": "int", + "default": 40000, + "range": [0, 500000], + "special": [-1], # Special value outside of the range + }, + ) @pytest.fixture @@ -62,9 +68,12 @@ def tunable_float() -> Tunable: tunable : Tunable An instance of a float Tunable. 
""" - return Tunable("chaos_monkey_prob", { - "description": "Probability of spontaneous VM shutdown", - "type": "float", - "default": 0.01, - "range": [0, 1] - }) + return Tunable( + "chaos_monkey_prob", + { + "description": "Probability of spontaneous VM shutdown", + "type": "float", + "default": 0.01, + "range": [0, 1], + }, + ) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py index 0e910f3761..e8b3e6b4cc 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py @@ -38,7 +38,7 @@ def test_tunable_categorical_types() -> None: "values": ["a", "b", "c"], "default": "a", }, - } + }, } } tunable_groups = TunableGroups(tunable_params) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index 58bb0368b1..b29c3a1b9e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -14,6 +14,7 @@ # Note: these test do *not* check the ConfigSpace conversions for those same Tunables. # That is checked indirectly via grid_search_optimizer_test.py + def test_tunable_int_size_props() -> None: """Test tunable int size properties""" tunable = Tunable( @@ -22,7 +23,8 @@ def test_tunable_int_size_props() -> None: "type": "int", "range": [1, 5], "default": 3, - }) + }, + ) assert tunable.span == 4 assert tunable.cardinality == 5 expected = [1, 2, 3, 4, 5] @@ -38,7 +40,8 @@ def test_tunable_float_size_props() -> None: "type": "float", "range": [1.5, 5], "default": 3, - }) + }, + ) assert tunable.span == 3.5 assert tunable.cardinality == np.inf assert tunable.quantized_values is None @@ -53,7 +56,8 @@ def test_tunable_categorical_size_props() -> None: "type": "categorical", "values": ["a", "b", "c"], "default": "a", - }) + }, + ) with pytest.raises(AssertionError): _ = tunable.span assert tunable.cardinality == 3 @@ -70,8 +74,9 @@ def test_tunable_quantized_int_size_props() -> None: "type": "int", "range": [100, 1000], "default": 100, - "quantization": 100 - }) + "quantization": 100, + }, + ) assert tunable.span == 900 assert tunable.cardinality == 10 expected = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] @@ -83,12 +88,8 @@ def test_tunable_quantized_float_size_props() -> None: """Test quantized tunable float size properties""" tunable = Tunable( name="test", - config={ - "type": "float", - "range": [0, 1], - "default": 0, - "quantization": .1 - }) + config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1}, + ) assert tunable.span == 1 assert tunable.cardinality == 11 expected = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py index 6a91b14016..407998b3a4 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py @@ -28,7 +28,7 @@ def test_tunable_int_name_lt(tunable_int: Tunable) -> None: Tests that the __lt__ operator works as expected. 
""" tunable_int_2 = tunable_int.copy() - tunable_int_2._name = "aaa" # pylint: disable=protected-access + tunable_int_2._name = "aaa" # pylint: disable=protected-access assert tunable_int_2 < tunable_int @@ -38,7 +38,8 @@ def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None: """ tunable_categorical_2 = tunable_categorical.copy() new_value = [ - x for x in tunable_categorical.categories + x + for x in tunable_categorical.categories if x != tunable_categorical.category and x is not None ][0] assert tunable_categorical.category is not None @@ -59,7 +60,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - } + }, ) tunable_dog = Tunable( name="same-name", @@ -67,7 +68,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": [None, "doggo"], "default": None, - } + }, ) assert tunable_dog < tunable_cat @@ -82,7 +83,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - } + }, ) tunable_int = Tunable( name="same-name", @@ -90,7 +91,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "int", "range": [1, 3], "default": 2, - } + }, ) assert tunable_cat < tunable_int @@ -101,7 +102,7 @@ def test_tunable_lt_different_object(tunable_int: Tunable) -> None: """ assert (tunable_int < "foo") is False with pytest.raises(TypeError): - assert "foo" < tunable_int # type: ignore[operator] + assert "foo" < tunable_int # type: ignore[operator] def test_tunable_group_ne_object(tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py index f2da3ba60e..d2ab29f27d 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py @@ -18,7 +18,9 @@ def test_tunable_name() -> None: """ with pytest.raises(ValueError): # ! 
characters are currently disallowed in tunable names - Tunable(name='test!tunable', config={"type": "float", "range": [0, 1], "default": 0}) + Tunable( + name="test!tunable", config={"type": "float", "range": [0, 1], "default": 0} + ) def test_categorical_required_params() -> None: @@ -34,7 +36,7 @@ def test_categorical_required_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_weights() -> None: @@ -50,7 +52,7 @@ def test_categorical_weights() -> None: } """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.weights == [25, 25, 50] @@ -68,7 +70,7 @@ def test_categorical_weights_wrong_count() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_weights_wrong_values() -> None: @@ -85,7 +87,7 @@ def test_categorical_weights_wrong_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_wrong_params() -> None: @@ -102,7 +104,7 @@ def test_categorical_wrong_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_disallow_special_values() -> None: @@ -119,7 +121,7 @@ def test_categorical_disallow_special_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_tunable_disallow_repeats() -> None: @@ -127,37 +129,50 @@ def test_categorical_tunable_disallow_repeats() -> None: Disallow duplicate values in categorical tunables. """ with pytest.raises(ValueError): - Tunable(name='test', config={ - "type": "categorical", - "values": ["foo", "bar", "foo"], - "default": "foo", - }) + Tunable( + name="test", + config={ + "type": "categorical", + "values": ["foo", "bar", "foo"], + "default": "foo", + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) -def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeName) -> None: +def test_numerical_tunable_disallow_null_default( + tunable_type: TunableValueTypeName, +) -> None: """ Disallow null values as default for numerical tunables. """ with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "default": None, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "default": None, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) -def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeName) -> None: +def test_numerical_tunable_disallow_out_of_range( + tunable_type: TunableValueTypeName, +) -> None: """ Disallow out of range values as default for numerical tunables. 
""" with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "default": 11, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "default": 11, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -166,12 +181,15 @@ def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> N Disallow values param for numerical tunables. """ with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "values": ["foo", "bar"], - "default": 0, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "values": ["foo", "bar"], + "default": 0, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -188,7 +206,7 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -205,7 +223,7 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(AssertionError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -222,7 +240,7 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -241,7 +259,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.special == [0] assert tunable.weights == [0.1] assert tunable.range_weight == 0.9 @@ -261,7 +279,7 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.quantization == 10 assert not tunable.is_log @@ -280,7 +298,7 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.is_log @@ -299,7 +317,7 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -319,7 +337,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.special == [-1, 0] assert tunable.weights == [0, 10] # Zero weights are ok assert tunable.range_weight == 90 @@ -342,7 +360,7 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with 
pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -361,7 +379,7 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -380,7 +398,7 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -398,7 +416,7 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -418,7 +436,7 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -436,7 +454,7 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_bad_type() -> None: @@ -452,4 +470,4 @@ def test_bad_type() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test_bad_type', config=config) + Tunable(name="test_bad_type", config=config) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py index deffcb6a46..e8817319ab 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py @@ -17,14 +17,15 @@ def test_categorical_distribution() -> None: Try to instantiate a categorical tunable with distribution specified. """ with pytest.raises(ValueError): - Tunable(name='test', config={ - "type": "categorical", - "values": ["foo", "bar", "baz"], - "distribution": { - "type": "uniform" + Tunable( + name="test", + config={ + "type": "categorical", + "values": ["foo", "bar", "baz"], + "distribution": {"type": "uniform"}, + "default": "foo", }, - "default": "foo" - }) + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -32,14 +33,15 @@ def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> N """ Create a numeric Tunable with explicit uniform distribution. """ - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "uniform" + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "uniform"}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.is_numerical assert tunable.distribution == "uniform" assert not tunable.distribution_params @@ -50,18 +52,15 @@ def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> No """ Create a numeric Tunable with explicit Gaussian distribution specified. 
""" - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "normal", - "params": { - "mu": 0, - "sigma": 1.0 - } + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "normal", "params": {"mu": 0, "sigma": 1.0}}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.distribution == "normal" assert tunable.distribution_params == {"mu": 0, "sigma": 1.0} @@ -71,18 +70,15 @@ def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None """ Create a numeric Tunable with explicit Beta distribution specified. """ - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "beta", - "params": { - "alpha": 2, - "beta": 5 - } + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "beta", "params": {"alpha": 2, "beta": 5}}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.distribution == "beta" assert tunable.distribution_params == {"alpha": 2, "beta": 5} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py index c6fb5670f0..eb73b34d12 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py @@ -10,7 +10,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categorical: Tunable) -> None: +def test_tunable_group_indexing( + tunable_groups: TunableGroups, tunable_categorical: Tunable +) -> None: """ Check that various types of indexing work for the tunable group. """ @@ -20,7 +22,9 @@ def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categoric # NOTE: we reassign the tunable_categorical here since they come from # different fixtures so are technically different objects. - (tunable_categorical, covariant_group) = tunable_groups.get_tunable(tunable_categorical.name) + (tunable_categorical, covariant_group) = tunable_groups.get_tunable( + tunable_categorical.name + ) assert tunable_groups.get_tunable(tunable_categorical)[0] == tunable_categorical assert tunable_categorical in covariant_group @@ -40,7 +44,9 @@ def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categoric assert covariant_group[tunable_categorical.name] == tunable_categorical.value # Check that we can assign a new value by index. - new_value = [x for x in tunable_categorical.categories if x != tunable_categorical.value][0] + new_value = [ + x for x in tunable_categorical.categories if x != tunable_categorical.value + ][0] tunable_groups[tunable_categorical] = new_value assert tunable_groups[tunable_categorical] == new_value assert tunable_groups[tunable_categorical.name] == new_value diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py index 55a485e951..186de4acfa 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py @@ -14,4 +14,4 @@ def test_tunable_group_subgroup(tunable_groups: TunableGroups) -> None: Check that the subgroup() method returns only a selection of tunable parameters. 
""" tunables = tunable_groups.subgroup(["provision"]) - assert tunables.get_param_values() == {'vmSize': 'Standard_B4ms'} + assert tunables.get_param_values() == {"vmSize": "Standard_B4ms"} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index 73e3a12caa..0dfbdd2acd 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -36,37 +36,39 @@ @pytest.mark.parametrize("param_type", ["int", "float"]) -@pytest.mark.parametrize("distr_name,distr_params", [ - ("normal", {"mu": 0.0, "sigma": 1.0}), - ("beta", {"alpha": 2, "beta": 5}), - ("uniform", {}), -]) -def test_convert_numerical_distributions(param_type: str, - distr_name: DistributionName, - distr_params: dict) -> None: +@pytest.mark.parametrize( + "distr_name,distr_params", + [ + ("normal", {"mu": 0.0, "sigma": 1.0}), + ("beta", {"alpha": 2, "beta": 5}), + ("uniform", {}), + ], +) +def test_convert_numerical_distributions( + param_type: str, distr_name: DistributionName, distr_params: dict +) -> None: """ Convert a numerical Tunable with explicit distribution to ConfigSpace. """ tunable_name = "x" - tunable_groups = TunableGroups({ - "tunable_group": { - "cost": 1, - "params": { - tunable_name: { - "type": param_type, - "range": [0, 100], - "special": [-1, 0], - "special_weights": [0.1, 0.2], - "range_weight": 0.7, - "distribution": { - "type": distr_name, - "params": distr_params - }, - "default": 0 - } + tunable_groups = TunableGroups( + { + "tunable_group": { + "cost": 1, + "params": { + tunable_name: { + "type": param_type, + "range": [0, 100], + "special": [-1, 0], + "special_weights": [0.1, 0.2], + "range_weight": 0.7, + "distribution": {"type": distr_name, "params": distr_params}, + "default": 0, + } + }, } } - }) + ) (tunable, _group) = tunable_groups.get_tunable(tunable_name) assert tunable.distribution == distr_name @@ -82,5 +84,5 @@ def test_convert_numerical_distributions(param_type: str, cs_param = space[tunable_name] assert isinstance(cs_param, _CS_HYPERPARAMETER[param_type, distr_name]) - for (key, val) in distr_params.items(): + for key, val in distr_params.items(): assert getattr(cs_param, key) == val diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 78e91fd25e..c92187a3e7 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -38,17 +38,23 @@ def configuration_space() -> ConfigurationSpace: configuration_space : ConfigurationSpace A new ConfigurationSpace object for testing. 
""" - (kernel_sched_migration_cost_ns_special, - kernel_sched_migration_cost_ns_type) = special_param_names("kernel_sched_migration_cost_ns") - - spaces = ConfigurationSpace(space={ - "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - "idle": ["halt", "mwait", "noidle"], - "kernel_sched_migration_cost_ns": (0, 500000), - kernel_sched_migration_cost_ns_special: [-1, 0], - kernel_sched_migration_cost_ns_type: [TunableValueKind.SPECIAL, TunableValueKind.RANGE], - "kernel_sched_latency_ns": (0, 1000000000), - }) + (kernel_sched_migration_cost_ns_special, kernel_sched_migration_cost_ns_type) = ( + special_param_names("kernel_sched_migration_cost_ns") + ) + + spaces = ConfigurationSpace( + space={ + "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + "idle": ["halt", "mwait", "noidle"], + "kernel_sched_migration_cost_ns": (0, 500000), + kernel_sched_migration_cost_ns_special: [-1, 0], + kernel_sched_migration_cost_ns_type: [ + TunableValueKind.SPECIAL, + TunableValueKind.RANGE, + ], + "kernel_sched_latency_ns": (0, 1000000000), + } + ) # NOTE: FLAML requires distribution to be uniform spaces["vmSize"].default_value = "Standard_B4ms" @@ -60,18 +66,27 @@ def configuration_space() -> ConfigurationSpace: spaces[kernel_sched_migration_cost_ns_type].probabilities = (0.5, 0.5) spaces["kernel_sched_latency_ns"].default_value = 2000000 - spaces.add_condition(EqualsCondition( - spaces[kernel_sched_migration_cost_ns_special], - spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.SPECIAL)) - spaces.add_condition(EqualsCondition( - spaces["kernel_sched_migration_cost_ns"], - spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.RANGE)) + spaces.add_condition( + EqualsCondition( + spaces[kernel_sched_migration_cost_ns_special], + spaces[kernel_sched_migration_cost_ns_type], + TunableValueKind.SPECIAL, + ) + ) + spaces.add_condition( + EqualsCondition( + spaces["kernel_sched_migration_cost_ns"], + spaces[kernel_sched_migration_cost_ns_type], + TunableValueKind.RANGE, + ) + ) return spaces def _cmp_tunable_hyperparameter_categorical( - tunable: Tunable, space: ConfigurationSpace) -> None: + tunable: Tunable, space: ConfigurationSpace +) -> None: """ Check if categorical Tunable and ConfigSpace Hyperparameter actually match. """ @@ -82,7 +97,8 @@ def _cmp_tunable_hyperparameter_categorical( def _cmp_tunable_hyperparameter_numerical( - tunable: Tunable, space: ConfigurationSpace) -> None: + tunable: Tunable, space: ConfigurationSpace +) -> None: """ Check if integer Tunable and ConfigSpace Hyperparameter actually match. """ @@ -130,12 +146,13 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> Non Make sure that the corresponding Tunable and Hyperparameter objects match. """ space = tunable_groups_to_configspace(tunable_groups) - for (tunable, _group) in tunable_groups: + for tunable, _group in tunable_groups: _CMP_FUNC[tunable.type](tunable, space) def test_tunable_groups_to_configspace( - tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None: + tunable_groups: TunableGroups, configuration_space: ConfigurationSpace +) -> None: """ Check the conversion of the entire TunableGroups collection to a single ConfigurationSpace object. 
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index cbccd6bfe1..2f7790602f 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -19,12 +19,14 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: that don't exist in the TunableGroups object. """ with pytest.raises(KeyError): - tunable_groups.assign({ - "vmSize": "Standard_B2ms", - "idle": "mwait", - "UnknownParam_1": 1, - "UnknownParam_2": "invalid-value" - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2ms", + "idle": "mwait", + "UnknownParam_1": 1, + "UnknownParam_2": "invalid-value", + } + ) def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None: @@ -106,7 +108,7 @@ def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None: Check str to int coercion. """ tunable_int.value = "10" - assert tunable_int.value == 10 # type: ignore[comparison-overlap] + assert tunable_int.value == 10 # type: ignore[comparison-overlap] assert not tunable_int.is_special @@ -115,7 +117,7 @@ def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None: Check str to float coercion. """ tunable_float.value = "0.5" - assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] + assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] assert not tunable_float.is_special @@ -149,12 +151,12 @@ def test_tunable_assign_null_to_categorical() -> None: } """ config = json.loads(json_config) - categorical_tunable = Tunable(name='categorical_test', config=config) + categorical_tunable = Tunable(name="categorical_test", config=config) assert categorical_tunable assert categorical_tunable.category == "foo" categorical_tunable.value = None assert categorical_tunable.value is None - assert categorical_tunable.value != 'None' + assert categorical_tunable.value != "None" assert categorical_tunable.category is None @@ -165,7 +167,7 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_int.value = None with pytest.raises((TypeError, AssertionError)): - tunable_int.numerical_value = None # type: ignore[assignment] + tunable_int.numerical_value = None # type: ignore[assignment] def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: @@ -175,7 +177,7 @@ def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_float.value = None with pytest.raises((TypeError, AssertionError)): - tunable_float.numerical_value = None # type: ignore[assignment] + tunable_float.numerical_value = None # type: ignore[assignment] def test_tunable_assign_special(tunable_int: Tunable) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py index 672b16ab73..cb41f7f7d8 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py @@ -17,42 +17,44 @@ def test_tunable_groups_str(tunable_groups: TunableGroups) -> None: tunables within each covariant group. 
""" # Same as `tunable_groups` (defined in the `conftest.py` file), but in different order: - tunables_other = TunableGroups({ - "kernel": { - "cost": 1, - "params": { - "kernel_sched_latency_ns": { - "type": "int", - "default": 2000000, - "range": [0, 1000000000] + tunables_other = TunableGroups( + { + "kernel": { + "cost": 1, + "params": { + "kernel_sched_latency_ns": { + "type": "int", + "default": 2000000, + "range": [0, 1000000000], + }, + "kernel_sched_migration_cost_ns": { + "type": "int", + "default": -1, + "range": [0, 500000], + "special": [-1], + }, }, - "kernel_sched_migration_cost_ns": { - "type": "int", - "default": -1, - "range": [0, 500000], - "special": [-1] - } - } - }, - "boot": { - "cost": 300, - "params": { - "idle": { - "type": "categorical", - "default": "halt", - "values": ["halt", "mwait", "noidle"] - } - } - }, - "provision": { - "cost": 1000, - "params": { - "vmSize": { - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] - } - } - }, - }) + }, + "boot": { + "cost": 300, + "params": { + "idle": { + "type": "categorical", + "default": "halt", + "values": ["halt", "mwait", "noidle"], + } + }, + }, + "provision": { + "cost": 1000, + "params": { + "vmSize": { + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + } + }, + }, + } + ) assert str(tunable_groups) == str(tunables_other) diff --git a/mlos_bench/mlos_bench/tunables/__init__.py b/mlos_bench/mlos_bench/tunables/__init__.py index 4191f37d89..3433f4a735 100644 --- a/mlos_bench/mlos_bench/tunables/__init__.py +++ b/mlos_bench/mlos_bench/tunables/__init__.py @@ -10,7 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups __all__ = [ - 'Tunable', - 'TunableValue', - 'TunableGroups', + "Tunable", + "TunableValue", + "TunableGroups", ] diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index fee4fd5841..797510a087 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -93,10 +93,12 @@ def __eq__(self, other: object) -> bool: return False # TODO: May need to provide logic to relax the equality check on the # tunables (e.g. "compatible" vs. "equal"). 
- return (self._name == other._name and - self._cost == other._cost and - self._is_updated == other._is_updated and - self._tunables == other._tunables) + return ( + self._name == other._name + and self._cost == other._cost + and self._is_updated == other._is_updated + and self._tunables == other._tunables + ) def equals_defaults(self, other: "CovariantTunableGroup") -> bool: """ @@ -234,7 +236,11 @@ def __contains__(self, tunable: Union[str, Tunable]) -> bool: def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: return self.get_tunable(tunable).value - def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: - value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + def __setitem__( + self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + ) -> TunableValue: + value: TunableValue = ( + tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + ) self._is_updated |= self.get_tunable(tunable).update(value) return value diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 1ebd70dfa4..1886d09597 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -107,7 +107,9 @@ def __init__(self, name: str, config: TunableDict): config : dict Python dict that represents a Tunable (e.g., deserialized from JSON) """ - if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema + if ( + not isinstance(name, str) or "!" in name + ): # TODO: Use a regex here and in JSON schema raise ValueError(f"Invalid name of the tunable: {name}") self._name = name self._type: TunableValueTypeName = config["type"] # required @@ -115,7 +117,9 @@ def __init__(self, name: str, config: TunableDict): raise ValueError(f"Invalid parameter type: {self._type}") self._description = config.get("description") self._default = config["default"] - self._default = self.dtype(self._default) if self._default is not None else self._default + self._default = ( + self.dtype(self._default) if self._default is not None else self._default + ) self._values = config.get("values") if self._values: self._values = [str(v) if v is not None else v for v in self._values] @@ -154,7 +158,9 @@ def _sanity_check(self) -> None: else: raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}") if not self.is_valid(self.default): - raise ValueError(f"Invalid default value for tunable {self}: {self.default}") + raise ValueError( + f"Invalid default value for tunable {self}: {self.default}" + ) def _sanity_check_categorical(self) -> None: """ @@ -163,11 +169,17 @@ def _sanity_check_categorical(self) -> None: # pylint: disable=too-complex assert self.is_categorical if not (self._values and isinstance(self._values, collections.abc.Iterable)): - raise ValueError(f"Must specify values for the categorical type tunable {self}") + raise ValueError( + f"Must specify values for the categorical type tunable {self}" + ) if self._range is not None: - raise ValueError(f"Range must be None for the categorical type tunable {self}") + raise ValueError( + f"Range must be None for the categorical type tunable {self}" + ) if len(set(self._values)) != len(self._values): - raise ValueError(f"Values must be unique for the categorical type tunable {self}") + raise ValueError( + f"Values must be unique for the categorical type tunable {self}" + ) if self._special: raise 
ValueError(f"Categorical tunable cannot have special values: {self}") if self._range_weight is not None: @@ -175,9 +187,13 @@ def _sanity_check_categorical(self) -> None: if self._log is not None: raise ValueError(f"Categorical tunable cannot have log parameter: {self}") if self._quantization is not None: - raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}") + raise ValueError( + f"Categorical tunable cannot have quantization parameter: {self}" + ) if self._distribution is not None: - raise ValueError(f"Categorical parameters do not support `distribution`: {self}") + raise ValueError( + f"Categorical parameters do not support `distribution`: {self}" + ) if self._weights: if len(self._weights) != len(self._values): raise ValueError(f"Must specify weights for all values: {self}") @@ -191,21 +207,31 @@ def _sanity_check_numerical(self) -> None: # pylint: disable=too-complex,too-many-branches assert self.is_numerical if self._values is not None: - raise ValueError(f"Values must be None for the numerical type tunable {self}") + raise ValueError( + f"Values must be None for the numerical type tunable {self}" + ) if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]: raise ValueError(f"Invalid range for tunable {self}: {self._range}") if self._quantization is not None: if self.dtype == int: if not isinstance(self._quantization, int): - raise ValueError(f"Quantization of a int param should be an int: {self}") + raise ValueError( + f"Quantization of a int param should be an int: {self}" + ) if self._quantization <= 1: raise ValueError(f"Number of quantization points is <= 1: {self}") if self.dtype == float: if not isinstance(self._quantization, (float, int)): - raise ValueError(f"Quantization of a float param should be a float or int: {self}") + raise ValueError( + f"Quantization of a float param should be a float or int: {self}" + ) if self._quantization <= 0: raise ValueError(f"Number of quantization points is <= 0: {self}") - if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}: + if self._distribution is not None and self._distribution not in { + "uniform", + "normal", + "beta", + }: raise ValueError(f"Invalid distribution: {self}") if self._distribution_params and self._distribution is None: raise ValueError(f"Must specify the distribution: {self}") @@ -217,7 +243,9 @@ def _sanity_check_numerical(self) -> None: if any(w < 0 for w in self._weights + [self._range_weight]): raise ValueError(f"All weights must be non-negative: {self}") elif self._range_weight is not None: - raise ValueError(f"Must specify both weights and range_weight or none: {self}") + raise ValueError( + f"Must specify both weights and range_weight or none: {self}" + ) def __repr__(self) -> str: """ @@ -251,12 +279,14 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Tunable): return False return bool( - self._name == other._name and - self._type == other._type and - self._current_value == other._current_value + self._name == other._name + and self._type == other._type + and self._current_value == other._current_value ) - def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements + def __lt__( + self, other: object + ) -> bool: # pylint: disable=too-many-return-statements """ Compare the two Tunable objects. We mostly need this to create a canonical list of tunable objects when hashing a TunableGroup. 
@@ -336,18 +366,33 @@ def value(self, value: TunableValue) -> TunableValue: assert value is not None coerced_value = self.dtype(value) except Exception: - _LOG.error("Impossible conversion: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Impossible conversion: %s %s <- %s %s", + self._type, + self._name, + type(value), + value, + ) raise if self._type == "int" and isinstance(value, float) and value != coerced_value: - _LOG.error("Loss of precision: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Loss of precision: %s %s <- %s %s", + self._type, + self._name, + type(value), + value, + ) raise ValueError(f"Loss of precision: {self._name}={value}") if not self.is_valid(coerced_value): - _LOG.error("Invalid assignment: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Invalid assignment: %s %s <- %s %s", + self._type, + self._name, + type(value), + value, + ) raise ValueError(f"Invalid value for the tunable: {self._name}={value}") self._current_value = coerced_value @@ -392,7 +437,9 @@ def is_valid(self, value: TunableValue) -> bool: if isinstance(value, (int, float)): return self.in_range(value) or value in self._special else: - raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}") + raise ValueError( + f"Invalid value type for tunable {self}: {value}={type(value)}" + ) else: raise ValueError(f"Invalid parameter type: {self._type}") @@ -403,10 +450,10 @@ def in_range(self, value: Union[int, float, str, None]) -> bool: Return False if the tunable or value is categorical or None. """ return ( - isinstance(value, (float, int)) and - self.is_numerical and - self._range is not None and - bool(self._range[0] <= value <= self._range[1]) + isinstance(value, (float, int)) + and self.is_numerical + and self._range is not None + and bool(self._range[0] <= value <= self._range[1]) ) @property @@ -626,12 +673,19 @@ def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]: # Be sure to return python types instead of numpy types. cardinality = self.cardinality assert isinstance(cardinality, int) - return (float(x) for x in np.linspace(start=num_range[0], - stop=num_range[1], - num=cardinality, - endpoint=True)) + return ( + float(x) + for x in np.linspace( + start=num_range[0], + stop=num_range[1], + num=cardinality, + endpoint=True, + ) + ) assert self.type == "int", f"Unhandled tunable type: {self}" - return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1)) + return range( + int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1) + ) @property def cardinality(self) -> Union[int, float]: @@ -706,7 +760,9 @@ def categories(self) -> List[Optional[str]]: return self._values @property - def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]: + def values( + self, + ) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]: """ Gets the categories or quantized values for this tunable. 
diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index 0bd58c8269..b48da6fccb 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -30,9 +30,11 @@ def __init__(self, config: Optional[dict] = None): if config is None: config = {} ConfigSchema.TUNABLE_PARAMS.validate(config) - self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup) + self._index: Dict[str, CovariantTunableGroup] = ( + {} + ) # Index (Tunable id -> CovariantTunableGroup) self._tunable_groups: Dict[str, CovariantTunableGroup] = {} - for (name, group_config) in config.items(): + for name, group_config in config.items(): self._add_group(CovariantTunableGroup(name, group_config)) def __bool__(self) -> bool: @@ -81,11 +83,15 @@ def _add_group(self, group: CovariantTunableGroup) -> None: ---------- group : CovariantTunableGroup """ - assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}" + assert ( + group.name not in self._tunable_groups + ), f"Duplicate covariant tunable group name {group.name} in {self}" self._tunable_groups[group.name] = group for tunable in group.get_tunables(): if tunable.name in self._index: - raise ValueError(f"Duplicate Tunable {tunable.name} from group {group.name} in {self}") + raise ValueError( + f"Duplicate Tunable {tunable.name} from group {group.name} in {self}" + ) self._index[tunable.name] = group def merge(self, tunables: "TunableGroups") -> "TunableGroups": @@ -119,8 +125,10 @@ def merge(self, tunables: "TunableGroups") -> "TunableGroups": # Check that there's no overlap in the tunables. # But allow for differing current values. if not self._tunable_groups[group.name].equals_defaults(group): - raise ValueError(f"Overlapping covariant tunable group name {group.name} " + - "in {self._tunable_groups[group.name]} and {tunables}") + raise ValueError( + f"Overlapping covariant tunable group name {group.name} " + + "in {self._tunable_groups[group.name]} and {tunables}" + ) return self def __repr__(self) -> str: @@ -132,10 +140,17 @@ def __repr__(self) -> str: string : str A human-readable version of the TunableGroups. """ - return "{ " + ", ".join( - f"{group.name}::{tunable}" - for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) - for tunable in sorted(group._tunables.values())) + " }" + return ( + "{ " + + ", ".join( + f"{group.name}::{tunable}" + for group in sorted( + self._tunable_groups.values(), key=lambda g: (-g.cost, g.name) + ) + for tunable in sorted(group._tunables.values()) + ) + + " }" + ) def __contains__(self, tunable: Union[str, Tunable]) -> bool: """ @@ -151,13 +166,17 @@ def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: name: str = tunable.name if isinstance(tunable, Tunable) else tunable return self._index[name][name] - def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: + def __setitem__( + self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + ) -> TunableValue: """ Update the current value of a single tunable parameter. 
""" # Use double index to make sure we set the is_updated flag of the group name: str = tunable.name if isinstance(tunable, Tunable) else tunable - value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + value: TunableValue = ( + tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + ) self._index[name][name] = value return self._index[name][name] @@ -171,9 +190,13 @@ def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, Non An iterator over all tunables in all groups. Each element is a 2-tuple of an instance of the Tunable parameter and covariant group it belongs to. """ - return ((group.get_tunable(name), group) for (name, group) in self._index.items()) + return ( + (group.get_tunable(name), group) for (name, group) in self._index.items() + ) - def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]: + def get_tunable( + self, tunable: Union[str, Tunable] + ) -> Tuple[Tunable, CovariantTunableGroup]: """ Access the entire Tunable (not just its value) and its covariant group. Throw KeyError if the tunable is not found. @@ -228,12 +251,17 @@ def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": tunables = TunableGroups() for name in group_names: if name not in self._tunable_groups: - raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}") + raise KeyError( + f"Unknown covariant group name '{name}' in tunable group {self}" + ) tunables._add_group(self._tunable_groups[name]) return tunables - def get_param_values(self, group_names: Optional[Iterable[str]] = None, - into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: + def get_param_values( + self, + group_names: Optional[Iterable[str]] = None, + into_params: Optional[Dict[str, TunableValue]] = None, + ) -> Dict[str, TunableValue]: """ Get the current values of the tunables that belong to the specified covariance groups. @@ -272,8 +300,10 @@ def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: is_updated : bool True if any of the specified tunable groups has been updated, False otherwise. """ - return any(self._tunable_groups[name].is_updated() - for name in (group_names or self.get_covariant_group_names())) + return any( + self._tunable_groups[name].is_updated() + for name in (group_names or self.get_covariant_group_names()) + ) def is_defaults(self) -> bool: """ @@ -285,7 +315,9 @@ def is_defaults(self) -> bool: """ return all(group.is_defaults() for group in self._tunable_groups.values()) - def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": + def restore_defaults( + self, group_names: Optional[Iterable[str]] = None + ) -> "TunableGroups": """ Restore all tunable parameters to their default values. @@ -299,7 +331,7 @@ def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "Tuna self : TunableGroups Self-reference for chaining. """ - for name in (group_names or self.get_covariant_group_names()): + for name in group_names or self.get_covariant_group_names(): self._tunable_groups[name].restore_defaults() return self @@ -317,7 +349,7 @@ def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": self : TunableGroups Self-reference for chaining. 
""" - for name in (group_names or self.get_covariant_group_names()): + for name in group_names or self.get_covariant_group_names(): self._tunable_groups[name].reset_is_updated() return self diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 531988be97..2892543e5f 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -42,7 +42,9 @@ from mlos_bench.storage.base_storage import Storage # BaseTypeVar is a generic with a constraint of the three base classes. -BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage") +BaseTypeVar = TypeVar( + "BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage" +) BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"] @@ -71,8 +73,12 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> return dest -def merge_parameters(*, dest: dict, source: Optional[dict] = None, - required_keys: Optional[Iterable[str]] = None) -> dict: +def merge_parameters( + *, + dest: dict, + source: Optional[dict] = None, + required_keys: Optional[Iterable[str]] = None, +) -> dict: """ Merge the source config dict into the destination config. Pick from the source configs *ONLY* the keys that are already present @@ -132,8 +138,9 @@ def path_join(*args: str, abs_path: bool = False) -> str: return os.path.normpath(path).replace("\\", "/") -def prepare_class_load(config: dict, - global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]: +def prepare_class_load( + config: dict, global_config: Optional[Dict[str, Any]] = None +) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. @@ -155,8 +162,11 @@ def prepare_class_load(config: dict, merge_parameters(dest=class_config, source=global_config) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Instantiating: %s with config:\n%s", - class_name, json.dumps(class_config, indent=2)) + _LOG.debug( + "Instantiating: %s with config:\n%s", + class_name, + json.dumps(class_config, indent=2), + ) return (class_name, class_config) @@ -187,8 +197,9 @@ def get_class_from_name(class_name: str) -> type: # FIXME: Technically, this should return a type "class_name" derived from "base_class". -def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str, - *args: Any, **kwargs: Any) -> BaseTypeVar: +def instantiate_from_config( + base_class: Type[BaseTypeVar], class_name: str, *args: Any, **kwargs: Any +) -> BaseTypeVar: """ Factory method for a new class instantiated from config. @@ -220,7 +231,9 @@ def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str, return ret -def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None: +def check_required_params( + config: Mapping[str, Any], required_params: Iterable[str] +) -> None: """ Check if all required parameters are present in the configuration. Raise ValueError if any of the parameters are missing. 
@@ -238,7 +251,8 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s if missing_params: raise ValueError( "The following parameters must be provided in the configuration" - + f" or as command line arguments: {missing_params}") + + f" or as command line arguments: {missing_params}" + ) def get_git_info(path: str = __file__) -> Tuple[str, str, str]: @@ -257,11 +271,14 @@ def get_git_info(path: str = __file__) -> Tuple[str, str, str]: """ dirname = os.path.dirname(path) git_repo = subprocess.check_output( - ["git", "-C", dirname, "remote", "get-url", "origin"], text=True).strip() + ["git", "-C", dirname, "remote", "get-url", "origin"], text=True + ).strip() git_commit = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "HEAD"], text=True).strip() + ["git", "-C", dirname, "rev-parse", "HEAD"], text=True + ).strip() git_root = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True).strip() + ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True + ).strip() _LOG.debug("Current git branch: %s %s", git_repo, git_commit) rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root)) return (git_repo, git_commit, rel_path.replace("\\", "/")) @@ -317,7 +334,9 @@ def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]: return None if value is None else func(value) -def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime: +def utcify_timestamp( + timestamp: datetime, *, origin: Literal["utc", "local"] +) -> datetime: """ Augment a timestamp with zoneinfo if missing and convert it to UTC. @@ -355,7 +374,9 @@ def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> raise ValueError(f"Invalid origin: {origin}") -def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]: +def utcify_nullable_timestamp( + timestamp: Optional[datetime], *, origin: Literal["utc", "local"] +) -> Optional[datetime]: """ A nullable version of utcify_timestamp. """ @@ -367,7 +388,9 @@ def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal[ _MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) -def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series: +def datetime_parser( + datetime_col: pandas.Series, *, origin: Literal["utc", "local"] +) -> pandas.Series: """ Attempt to convert a pandas column to a datetime format. @@ -401,7 +424,7 @@ def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "loca new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo) assert new_datetime_col.dt.tz is not None # And convert it to UTC. - new_datetime_col = new_datetime_col.dt.tz_convert('UTC') + new_datetime_col = new_datetime_col.dt.tz_convert("UTC") if new_datetime_col.isna().any(): raise ValueError(f"Invalid date format in the data: {datetime_col}") if new_datetime_col.le(_MIN_TS).any(): diff --git a/mlos_bench/mlos_bench/version.py b/mlos_bench/mlos_bench/version.py index 96d3d2b6bf..f8acae8c02 100644 --- a/mlos_bench/mlos_bench/version.py +++ b/mlos_bench/mlos_bench/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. 
-VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index 27d844c35b..b2090424a6 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -21,21 +21,24 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: warning("setuptools_scm not found, using version from version.py") except LookupError as e: - warning(f"setuptools_scm failed to find git version, using version from version.py: {e}") + warning( + f"setuptools_scm failed to find git version, using version from version.py: {e}" + ) # A simple routine to read and adjust the README.md for this module into a format @@ -47,62 +50,72 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } -extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass - # Additional tools for extra functionality. - 'azure': ['azure-storage-file-share', 'azure-identity', 'azure-keyvault'], - 'ssh': ['asyncssh'], - 'storage-sql-duckdb': ['sqlalchemy', 'duckdb_engine'], - 'storage-sql-mysql': ['sqlalchemy', 'mysql-connector-python'], - 'storage-sql-postgres': ['sqlalchemy', 'psycopg2'], - 'storage-sql-sqlite': ['sqlalchemy'], # sqlite3 comes with python, so we don't need to install it. - # Transitive extra_requires from mlos-core. - 'flaml': ['flaml[blendsearch]'], - 'smac': ['smac'], -} +extra_requires: Dict[str, List[str]] = ( + { # pylint: disable=consider-using-namedtuple-or-dataclass + # Additional tools for extra functionality. 
+ "azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"], + "ssh": ["asyncssh"], + "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"], + "storage-sql-mysql": ["sqlalchemy", "mysql-connector-python"], + "storage-sql-postgres": ["sqlalchemy", "psycopg2"], + "storage-sql-sqlite": [ + "sqlalchemy" + ], # sqlite3 comes with python, so we don't need to install it. + # Transitive extra_requires from mlos-core. + "flaml": ["flaml[blendsearch]"], + "smac": ["smac"], + } +) # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', - 'pytest-lazy-fixtures', - 'pytest-docker', - 'fasteners', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", + "pytest-lazy-fixtures", + "pytest-docker", + "fasteners", ] setup( version=VERSION, install_requires=[ - 'mlos-core==' + VERSION, - 'requests', - 'json5', - 'jsonschema>=4.18.0', 'referencing>=0.29.1', + "mlos-core==" + VERSION, + "requests", + "json5", + "jsonschema>=4.18.0", + "referencing>=0.29.1", 'importlib_resources;python_version<"3.10"', - ] + extra_requires['storage-sql-sqlite'], # NOTE: For now sqlite is a fallback storage backend, so we always install it. + ] + + extra_requires[ + "storage-sql-sqlite" + ], # NOTE: For now sqlite is a fallback storage backend, so we always install it. extras_require=extra_requires, - **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_bench'), + **_get_long_desc_from_readme( + "https://github.com/microsoft/MLOS/tree/main/mlos_bench" + ), ) diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index 086002af62..b3e248e407 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -18,12 +18,12 @@ from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType __all__ = [ - 'SpaceAdapterType', - 'OptimizerFactory', - 'BaseOptimizer', - 'RandomOptimizer', - 'FlamlOptimizer', - 'SmacOptimizer', + "SpaceAdapterType", + "OptimizerFactory", + "BaseOptimizer", + "RandomOptimizer", + "FlamlOptimizer", + "SmacOptimizer", ] @@ -45,7 +45,7 @@ class OptimizerType(Enum): # ConcreteOptimizer = TypeVar('ConcreteOptimizer', *[member.value for member in OptimizerType]) # To address this, we add a test for complete coverage of the enum. 
ConcreteOptimizer = TypeVar( - 'ConcreteOptimizer', + "ConcreteOptimizer", RandomOptimizer, FlamlOptimizer, SmacOptimizer, @@ -60,13 +60,15 @@ class OptimizerFactory: # pylint: disable=too-few-public-methods @staticmethod - def create(*, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, - optimizer_kwargs: Optional[dict] = None, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None) -> ConcreteOptimizer: # type: ignore[type-var] + def create( + *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, + optimizer_kwargs: Optional[dict] = None, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None, + ) -> ConcreteOptimizer: # type: ignore[type-var] """ Create a new optimizer instance, given the parameter space, optimizer type, and potential optimizer options. @@ -107,7 +109,7 @@ def create(*, parameter_space=parameter_space, optimization_targets=optimization_targets, space_adapter=space_adapter, - **optimizer_kwargs + **optimizer_kwargs, ) return optimizer diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py index 5f32219988..d4f59dfa52 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py @@ -12,6 +12,6 @@ from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer __all__ = [ - 'BaseBayesianOptimizer', - 'SmacOptimizer', + "BaseBayesianOptimizer", + "SmacOptimizer", ] diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 76ff0d9b3a..9d3bcabcb2 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -19,8 +19,9 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): """Abstract base class defining the interface for Bayesian optimization.""" @abstractmethod - def surrogate_predict(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def surrogate_predict( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). Parameters @@ -31,11 +32,12 @@ def surrogate_predict(self, *, configs: pd.DataFrame, context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def acquisition_function(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def acquisition_function( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: """Invokes the acquisition function from this Bayesian optimizer for the given configuration. Parameters @@ -46,4 +48,4 @@ def acquisition_function(self, *, configs: pd.DataFrame, context : pd.DataFrame Not Yet Implemented. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 9d8d2a0347..4364f4c172 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -29,19 +29,22 @@ class SmacOptimizer(BaseBayesianOptimizer): Wrapper class for SMAC based Bayesian optimization. """ - def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - seed: Optional[int] = 0, - run_name: Optional[str] = None, - output_directory: Optional[str] = None, - max_trials: int = 100, - n_random_init: Optional[int] = None, - max_ratio: Optional[float] = None, - use_default_config: bool = False, - n_random_probability: float = 0.1): + def __init__( + self, + *, # pylint: disable=too-many-locals,too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + seed: Optional[int] = 0, + run_name: Optional[str] = None, + output_directory: Optional[str] = None, + max_trials: int = 100, + n_random_init: Optional[int] = None, + max_ratio: Optional[float] = None, + use_default_config: bool = False, + n_random_probability: float = 0.1, + ): """ Instantiate a new SMAC optimizer wrapper. @@ -124,7 +127,9 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments if output_directory is None: # pylint: disable=consider-using-with try: - self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True) # Argument added in Python 3.10 + self._temp_output_directory = TemporaryDirectory( + ignore_cleanup_errors=True + ) # Argument added in Python 3.10 except TypeError: self._temp_output_directory = TemporaryDirectory() output_directory = self._temp_output_directory.name @@ -146,8 +151,12 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments seed=seed or -1, # if -1, SMAC will generate a random seed internally n_workers=1, # Use a single thread for evaluating trials ) - intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier(scenario, max_config_calls=1) - config_selector: ConfigSelector = Optimizer_Smac.get_config_selector(scenario, retrain_after=1) + intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier( + scenario, max_config_calls=1 + ) + config_selector: ConfigSelector = Optimizer_Smac.get_config_selector( + scenario, retrain_after=1 + ) # TODO: When bulk registering prior configs to rewarm the optimizer, # there is a way to inform SMAC's initial design that we have @@ -158,27 +167,27 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments # See Also: #488 initial_design_args: Dict[str, Union[list, int, float, Scenario]] = { - 'scenario': scenario, + "scenario": scenario, # Workaround a bug in SMAC that sets a default arg to a mutable # value that can cause issues when multiple optimizers are # instantiated with the use_default_config option within the same # process that use different ConfigSpaces so that the second # receives the default config from both as an additional config. 
- 'additional_configs': [] + "additional_configs": [], } if n_random_init is not None: - initial_design_args['n_configs'] = n_random_init + initial_design_args["n_configs"] = n_random_init if n_random_init > 0.25 * max_trials and max_ratio is None: warning( - 'Number of random initial configs (%d) is ' + - 'greater than 25%% of max_trials (%d). ' + - 'Consider setting max_ratio to avoid SMAC overriding n_random_init.', + "Number of random initial configs (%d) is " + + "greater than 25%% of max_trials (%d). " + + "Consider setting max_ratio to avoid SMAC overriding n_random_init.", n_random_init, max_trials, ) if max_ratio is not None: assert isinstance(max_ratio, float) and 0.0 <= max_ratio <= 1.0 - initial_design_args['max_ratio'] = max_ratio + initial_design_args["max_ratio"] = max_ratio # Use the default InitialDesign from SMAC. # (currently SBOL instead of LatinHypercube due to better uniformity @@ -190,7 +199,9 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments # design when generated a random_design for itself via the # get_random_design static method when random_design is None. assert isinstance(n_random_probability, float) and n_random_probability >= 0 - random_design = ProbabilityRandomDesign(probability=n_random_probability, seed=scenario.seed) + random_design = ProbabilityRandomDesign( + probability=n_random_probability, seed=scenario.seed + ) self.base_optimizer = Optimizer_Smac( scenario, @@ -200,7 +211,8 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments random_design=random_design, config_selector=config_selector, multi_objective_algorithm=Optimizer_Smac.get_multi_objective_algorithm( - scenario, objective_weights=self._objective_weights), + scenario, objective_weights=self._objective_weights + ), overwrite=True, logging_level=False, # Use the existing logger ) @@ -241,10 +253,16 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None """ # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946 - raise RuntimeError('This function should never be called.') - - def _register(self, *, configs: pd.DataFrame, - scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + raise RuntimeError("This function should never be called.") + + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. 
Parameters @@ -268,20 +286,30 @@ def _register(self, *, configs: pd.DataFrame, ) if context is not None: - warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) # Register each trial (one-by-one) - for (config, (_i, score)) in zip(self._to_configspace_configs(configs=configs), scores.iterrows()): + for config, (_i, score) in zip( + self._to_configspace_configs(configs=configs), scores.iterrows() + ): # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance info: TrialInfo = self.trial_info_map.get( - config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed)) - value = TrialValue(cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS) + config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) + ) + value = TrialValue( + cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS + ) self.base_optimizer.tell(info, value, save=False) # Save optimizer once we register all configs self.base_optimizer.optimizer.save() - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Parameters @@ -303,62 +331,99 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr ) if context is not None: - warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) trial: TrialInfo = self.base_optimizer.ask() trial.config.is_valid_configuration() self.optimizer_parameter_space.check_configuration(trial.config) assert trial.config.config_space == self.optimizer_parameter_space self.trial_info_map[trial.config] = trial - config_df = pd.DataFrame([trial.config], columns=list(self.optimizer_parameter_space.keys())) + config_df = pd.DataFrame( + [trial.config], columns=list(self.optimizer_parameter_space.keys()) + ) return config_df, None - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() - def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def surrogate_predict( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: from smac.utils.configspace import ( convert_configurations_to_array, # pylint: disable=import-outside-toplevel ) if context is not None: - warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) if self._space_adapter and not isinstance(self._space_adapter, IdentityAdapter): - raise NotImplementedError("Space adapter not supported for surrogate_predict.") + raise NotImplementedError( + "Space adapter not supported for surrogate_predict." 
+ ) # pylint: disable=protected-access if len(self._observations) <= self.base_optimizer._initial_design._n_configs: raise RuntimeError( - 'Surrogate model can make predictions *only* after all initial points have been evaluated ' + - f'{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}') + "Surrogate model can make predictions *only* after all initial points have been evaluated " + + f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}" + ) if self.base_optimizer._config_selector._model is None: - raise RuntimeError('Surrogate model is not yet trained') + raise RuntimeError("Surrogate model is not yet trained") - config_array: npt.NDArray = convert_configurations_to_array(self._to_configspace_configs(configs=configs)) - mean_predictions, _ = self.base_optimizer._config_selector._model.predict(config_array) - return mean_predictions.reshape(-1,) + config_array: npt.NDArray = convert_configurations_to_array( + self._to_configspace_configs(configs=configs) + ) + mean_predictions, _ = self.base_optimizer._config_selector._model.predict( + config_array + ) + return mean_predictions.reshape( + -1, + ) - def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def acquisition_function( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: if context is not None: - warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) if self._space_adapter: raise NotImplementedError() # pylint: disable=protected-access if self.base_optimizer._config_selector._acquisition_function is None: - raise RuntimeError('Acquisition function is not yet initialized') + raise RuntimeError("Acquisition function is not yet initialized") cs_configs: list = self._to_configspace_configs(configs=configs) - return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape(-1,) + return self.base_optimizer._config_selector._acquisition_function( + cs_configs + ).reshape( + -1, + ) def cleanup(self) -> None: - if hasattr(self, '_temp_output_directory') and self._temp_output_directory is not None: + if ( + hasattr(self, "_temp_output_directory") + and self._temp_output_directory is not None + ): self._temp_output_directory.cleanup() self._temp_output_directory = None - def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace.Configuration]: + def _to_configspace_configs( + self, *, configs: pd.DataFrame + ) -> List[ConfigSpace.Configuration]: """Convert a dataframe of configs to a list of ConfigSpace configs. Parameters @@ -372,6 +437,8 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. List of ConfigSpace configs. 
""" return [ - ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict()) - for (_, config) in configs.astype('O').iterrows() + ConfigSpace.Configuration( + self.optimizer_parameter_space, values=config.to_dict() + ) + for (_, config) in configs.astype("O").iterrows() ] diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 273c89eecc..638613c43d 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -33,13 +33,16 @@ class FlamlOptimizer(BaseOptimizer): # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. _METRIC_NAME = "FLAML_score" - def __init__(self, *, # pylint: disable=too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - low_cost_partial_config: Optional[dict] = None, - seed: Optional[int] = None): + def __init__( + self, + *, # pylint: disable=too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + low_cost_partial_config: Optional[dict] = None, + seed: Optional[int] = None, + ): """ Create an MLOS wrapper for FLAML. @@ -82,14 +85,22 @@ def __init__(self, *, # pylint: disable=too-many-arguments configspace_to_flaml_space, ) - self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(self.optimizer_parameter_space) + self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space( + self.optimizer_parameter_space + ) self.low_cost_partial_config = low_cost_partial_config self.evaluated_samples: Dict[ConfigSpace.Configuration, EvaluatedSample] = {} self._suggested_config: Optional[dict] - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. Parameters @@ -107,21 +118,34 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, Not Yet Implemented. 
""" if context is not None: - warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) if metadata is not None: - warn(f"Not Implemented: Ignoring metadata {list(metadata.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring metadata {list(metadata.columns)}", + UserWarning, + ) - for (_, config), (_, score) in zip(configs.astype('O').iterrows(), scores.iterrows()): + for (_, config), (_, score) in zip( + configs.astype("O").iterrows(), scores.iterrows() + ): cs_config: ConfigSpace.Configuration = ConfigSpace.Configuration( - self.optimizer_parameter_space, values=config.to_dict()) + self.optimizer_parameter_space, values=config.to_dict() + ) if cs_config in self.evaluated_samples: warn(f"Configuration {config} was already registered", UserWarning) self.evaluated_samples[cs_config] = EvaluatedSample( config=config.to_dict(), - score=float(np.average(score.astype(float), weights=self._objective_weights)), + score=float( + np.average(score.astype(float), weights=self._objective_weights) + ), ) - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Sampled at random using ConfigSpace. @@ -140,12 +164,20 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr Not implemented. """ if context is not None: - warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) config: dict = self._get_next_config() return pd.DataFrame(config, index=[0]), None - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() def _target_function(self, config: dict) -> Union[dict, None]: @@ -200,16 +232,14 @@ def _get_next_config(self) -> dict: dict(normalize_config(self.optimizer_parameter_space, conf)) for conf in self.evaluated_samples ] - evaluated_rewards = [ - s.score for s in self.evaluated_samples.values() - ] + evaluated_rewards = [s.score for s in self.evaluated_samples.values()] # Warm start FLAML optimizer self._suggested_config = None tune.run( self._target_function, config=self.flaml_parameter_space, - mode='min', + mode="min", metric=self._METRIC_NAME, points_to_evaluate=points_to_evaluate, evaluated_rewards=evaluated_rewards, @@ -218,6 +248,6 @@ def _get_next_config(self) -> dict: verbose=0, ) if self._suggested_config is None: - raise RuntimeError('FLAML did not produce a suggestion') + raise RuntimeError("FLAML did not produce a suggestion") return self._suggested_config # type: ignore[unreachable] diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 4ab9db5a2f..8e80de16f1 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -24,11 +24,14 @@ class BaseOptimizer(metaclass=ABCMeta): Optimizer abstract base class defining the basic interface. 
""" - def __init__(self, *, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None): + def __init__( + self, + *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + ): """ Create a new instance of the base optimizer. @@ -44,21 +47,37 @@ def __init__(self, *, The space adapter class to employ for parameter space transformations. """ self.parameter_space: ConfigSpace.ConfigurationSpace = parameter_space - self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = \ - parameter_space if space_adapter is None else space_adapter.target_parameter_space - - if space_adapter is not None and space_adapter.orig_parameter_space != parameter_space: - raise ValueError("Given parameter space differs from the one given to space adapter") + self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = ( + parameter_space + if space_adapter is None + else space_adapter.target_parameter_space + ) + + if ( + space_adapter is not None + and space_adapter.orig_parameter_space != parameter_space + ): + raise ValueError( + "Given parameter space differs from the one given to space adapter" + ) self._optimization_targets = optimization_targets self._objective_weights = objective_weights - if objective_weights is not None and len(objective_weights) != len(optimization_targets): - raise ValueError("Number of weights must match the number of optimization targets") + if objective_weights is not None and len(objective_weights) != len( + optimization_targets + ): + raise ValueError( + "Number of weights must match the number of optimization targets" + ) self._space_adapter: Optional[BaseSpaceAdapter] = space_adapter - self._observations: List[Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]] = [] + self._observations: List[ + Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]] + ] = [] self._has_context: Optional[bool] = None - self._pending_observations: List[Tuple[pd.DataFrame, Optional[pd.DataFrame]]] = [] + self._pending_observations: List[ + Tuple[pd.DataFrame, Optional[pd.DataFrame]] + ] = [] def __repr__(self) -> str: return f"{self.__class__.__name__}(space_adapter={self.space_adapter})" @@ -68,8 +87,14 @@ def space_adapter(self) -> Optional[BaseSpaceAdapter]: """Get the space adapter instance (if any).""" return self._space_adapter - def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Wrapper method, which employs the space adapter (if any), before registering the configs and scores. Parameters @@ -87,29 +112,39 @@ def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, """ # Do some input validation. assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(scores.columns) == set(self._optimization_targets), \ - "Mismatched optimization targets." - assert self._has_context is None or self._has_context ^ (context is None), \ - "Context must always be added or never be added." - assert len(configs) == len(scores), \ - "Mismatched number of configs and scores." 
+ assert set(scores.columns) == set( + self._optimization_targets + ), "Mismatched optimization targets." + assert self._has_context is None or self._has_context ^ ( + context is None + ), "Context must always be added or never be added." + assert len(configs) == len(scores), "Mismatched number of configs and scores." if context is not None: - assert len(configs) == len(context), \ - "Mismatched number of configs and context." - assert configs.shape[1] == len(self.parameter_space.values()), \ - "Mismatched configuration shape." + assert len(configs) == len( + context + ), "Mismatched number of configs and context." + assert configs.shape[1] == len( + self.parameter_space.values() + ), "Mismatched configuration shape." self._observations.append((configs, scores, context)) self._has_context = context is not None if self._space_adapter: configs = self._space_adapter.inverse_transform(configs) - assert configs.shape[1] == len(self.optimizer_parameter_space.values()), \ - "Mismatched configuration shape after inverse transform." + assert configs.shape[1] == len( + self.optimizer_parameter_space.values() + ), "Mismatched configuration shape after inverse transform." return self._register(configs=configs, scores=scores, context=context) @abstractmethod - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. Parameters @@ -122,10 +157,11 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover - def suggest(self, *, context: Optional[pd.DataFrame] = None, - defaults: bool = False) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def suggest( + self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. @@ -143,24 +179,31 @@ def suggest(self, *, context: Optional[pd.DataFrame] = None, Pandas dataframe with a single row. Column names are the parameter names. """ if defaults: - configuration = config_to_dataframe(self.parameter_space.get_default_configuration()) + configuration = config_to_dataframe( + self.parameter_space.get_default_configuration() + ) metadata = None if self.space_adapter is not None: configuration = self.space_adapter.inverse_transform(configuration) else: configuration, metadata = self._suggest(context=context) - assert len(configuration) == 1, \ - "Suggest must return a single configuration." - assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), \ - "Optimizer suggested a configuration that does not match the expected parameter space." + assert ( + len(configuration) == 1 + ), "Suggest must return a single configuration." + assert set(configuration.columns).issubset( + set(self.optimizer_parameter_space) + ), "Optimizer suggested a configuration that does not match the expected parameter space." 
if self._space_adapter: configuration = self._space_adapter.transform(configuration) - assert set(configuration.columns).issubset(set(self.parameter_space)), \ - "Space adapter produced a configuration that does not match the expected parameter space." + assert set(configuration.columns).issubset( + set(self.parameter_space) + ), "Space adapter produced a configuration that does not match the expected parameter space." return configuration, metadata @abstractmethod - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Parameters @@ -176,12 +219,16 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr metadata : Optional[pd.DataFrame] The metadata associated with the given configuration used for evaluations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs as "pending". That is it say, it has been suggested by the optimizer, and an experiment trial has been started. This can be useful for executing multiple trials in parallel, retry logic, etc. @@ -195,9 +242,11 @@ def register_pending(self, *, configs: pd.DataFrame, metadata : Optional[pd.DataFrame] Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover - def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: + def get_observations( + self, + ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Returns the observations as a triplet of DataFrames (config, score, context). @@ -208,13 +257,23 @@ def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.Data """ if len(self._observations) == 0: raise ValueError("No observations registered yet.") - configs = pd.concat([config for config, _, _ in self._observations]).reset_index(drop=True) - scores = pd.concat([score for _, score, _ in self._observations]).reset_index(drop=True) - contexts = pd.concat([pd.DataFrame() if context is None else context - for _, _, context in self._observations]).reset_index(drop=True) + configs = pd.concat( + [config for config, _, _ in self._observations] + ).reset_index(drop=True) + scores = pd.concat([score for _, score, _ in self._observations]).reset_index( + drop=True + ) + contexts = pd.concat( + [ + pd.DataFrame() if context is None else context + for _, _, context in self._observations + ] + ).reset_index(drop=True) return (configs, scores, contexts if len(contexts.columns) > 0 else None) - def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: + def get_best_observations( + self, *, n_max: int = 1 + ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Get the N best observations so far as a triplet of DataFrames (config, score, context). Default is N=1. The columns are ordered in ASCENDING order of the optimization targets. 
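A minimal usage sketch of the suggest / register / get_best_observations flow defined in optimizer.py above, assuming RandomOptimizer is re-exported from mlos_core.optimizers and using a toy single-parameter space; the objective function and loop length are illustrative, not taken from the patch:

    import ConfigSpace as CS
    import pandas as pd

    from mlos_core.optimizers import RandomOptimizer  # assumed re-export

    space = CS.ConfigurationSpace(seed=1234)
    space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1))
    optimizer = RandomOptimizer(parameter_space=space, optimization_targets=["score"])

    for _ in range(5):
        suggestion, metadata = optimizer.suggest()   # single-row config DataFrame
        score = pd.DataFrame({"score": [float(suggestion.x.iloc[0]) ** 2]})
        optimizer.register(configs=suggestion, scores=score, metadata=metadata)

    # Triplet of (configs, scores, contexts); contexts is None when no context was registered.
    configs, scores, contexts = optimizer.get_best_observations(n_max=3)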
@@ -233,9 +292,14 @@ def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.Dat if len(self._observations) == 0: raise ValueError("No observations registered yet.") (configs, scores, contexts) = self.get_observations() - idx = scores.nsmallest(n_max, columns=self._optimization_targets, keep="first").index - return (configs.loc[idx], scores.loc[idx], - None if contexts is None else contexts.loc[idx]) + idx = scores.nsmallest( + n_max, columns=self._optimization_targets, keep="first" + ).index + return ( + configs.loc[idx], + scores.loc[idx], + None if contexts is None else contexts.loc[idx], + ) def cleanup(self) -> None: """ @@ -253,7 +317,7 @@ def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame: j = 0 for param in self.optimizer_parameter_space.values(): if isinstance(param, ConfigSpace.CategoricalHyperparameter): - for (offset, val) in enumerate(param.choices): + for offset, val in enumerate(param.choices): if config[i][j + offset] == 1: df_dict[param.name].append(val) break diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index 0af785ef20..f1ce489b28 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -24,8 +24,14 @@ class RandomOptimizer(BaseOptimizer): The parameter space to optimize. """ - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. Doesn't do anything on the RandomOptimizer except storing configs for logging. @@ -45,12 +51,20 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, Not Yet Implemented. """ if context is not None: - warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) if metadata is not None: - warn(f"Not Implemented: Ignoring context {list(metadata.columns)}", UserWarning) + warn( + f"Not Implemented: Ignoring context {list(metadata.columns)}", + UserWarning, + ) # should we pop them from self.pending_observations? - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Sampled at random using ConfigSpace. @@ -70,10 +84,23 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr """ if context is not None: # not sure how that works here? 
- warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) - return pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), None - - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + warn( + f"Not Implemented: Ignoring context {list(context.columns)}", + UserWarning, + ) + return ( + pd.DataFrame( + dict(self.optimizer_parameter_space.sample_configuration()), index=[0] + ), + None, + ) + + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() # self._pending_observations.append((configs, context)) diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 2e2f585590..73e7f37dc3 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -15,8 +15,8 @@ from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter __all__ = [ - 'IdentityAdapter', - 'LlamaTuneAdapter', + "IdentityAdapter", + "LlamaTuneAdapter", ] @@ -35,7 +35,7 @@ class SpaceAdapterType(Enum): # ConcreteSpaceAdapter = TypeVar('ConcreteSpaceAdapter', *[member.value for member in SpaceAdapterType]) # To address this, we add a test for complete coverage of the enum. ConcreteSpaceAdapter = TypeVar( - 'ConcreteSpaceAdapter', + "ConcreteSpaceAdapter", IdentityAdapter, LlamaTuneAdapter, ) @@ -47,10 +47,12 @@ class SpaceAdapterFactory: # pylint: disable=too-few-public-methods @staticmethod - def create(*, - parameter_space: ConfigSpace.ConfigurationSpace, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None) -> ConcreteSpaceAdapter: # type: ignore[type-var] + def create( + *, + parameter_space: ConfigSpace.ConfigurationSpace, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None, + ) -> ConcreteSpaceAdapter: # type: ignore[type-var] """ Create a new space adapter instance, given the parameter space and potential space adapter options. @@ -75,8 +77,7 @@ def create(*, space_adapter_kwargs = {} space_adapter: ConcreteSpaceAdapter = space_adapter_type.value( - orig_parameter_space=parameter_space, - **space_adapter_kwargs + orig_parameter_space=parameter_space, **space_adapter_kwargs ) return space_adapter diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index 6c3a86fc8a..cc7b22b708 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -22,7 +22,9 @@ class BaseSpaceAdapter(metaclass=ABCMeta): """ def __init__(self, *, orig_parameter_space: ConfigSpace.ConfigurationSpace): - self._orig_parameter_space: ConfigSpace.ConfigurationSpace = orig_parameter_space + self._orig_parameter_space: ConfigSpace.ConfigurationSpace = ( + orig_parameter_space + ) self._random_state = orig_parameter_space.random def __repr__(self) -> str: @@ -46,7 +48,7 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: """ Target parameter space that is fed to the underlying optimizer. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @@ -64,7 +66,7 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: Pandas dataframe with a single row, containing the translated configuration. Column names are the parameter names of the original parameter space. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: @@ -84,4 +86,4 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: Dataframe of the translated configurations / parameters. The columns are the parameter names of the target parameter space and the rows are the configurations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 4d3a925cbc..9c98b772ec 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -19,7 +19,9 @@ from mlos_core.util import normalize_config -class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes +class LlamaTuneAdapter( + BaseSpaceAdapter +): # pylint: disable=too-many-instance-attributes """ Implementation of LlamaTune, a set of parameter space transformation techniques, aimed at improving the sample-efficiency of the underlying optimizer. @@ -28,18 +30,23 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance- DEFAULT_NUM_LOW_DIMS = 16 """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" - DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = .2 + DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2 """Default percentage of bias for each special parameter value""" DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 """Default number of (max) unique values of each parameter, when space discretization is used""" - def __init__(self, *, - orig_parameter_space: ConfigSpace.ConfigurationSpace, - num_low_dims: int = DEFAULT_NUM_LOW_DIMS, - special_param_values: Optional[dict] = None, - max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, - use_approximate_reverse_mapping: bool = False): + def __init__( + self, + *, + orig_parameter_space: ConfigSpace.ConfigurationSpace, + num_low_dims: int = DEFAULT_NUM_LOW_DIMS, + special_param_values: Optional[dict] = None, + max_unique_values_per_param: Optional[ + int + ] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, + use_approximate_reverse_mapping: bool = False, + ): """ Create a space adapter that employs LlamaTune's techniques. @@ -58,7 +65,9 @@ def __init__(self, *, super().__init__(orig_parameter_space=orig_parameter_space) if num_low_dims >= len(orig_parameter_space): - raise ValueError("Number of target config space dimensions should be less than those of original config space.") + raise ValueError( + "Number of target config space dimensions should be less than those of original config space." 
+ ) # Validate input special param values dict special_param_values = special_param_values or {} @@ -79,7 +88,9 @@ def __init__(self, *, self._sigma_vector = self._random_state.choice([-1, 1], num_orig_dims) # Used to retrieve the low-dim point, given the high-dim one - self._suggested_configs: Dict[ConfigSpace.Configuration, ConfigSpace.Configuration] = {} + self._suggested_configs: Dict[ + ConfigSpace.Configuration, ConfigSpace.Configuration + ] = {} self._pinv_matrix: npt.NDArray self._use_approximate_reverse_mapping = use_approximate_reverse_mapping @@ -90,9 +101,10 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: target_configurations = [] - for (_, config) in configurations.astype('O').iterrows(): + for _, config in configurations.astype("O").iterrows(): configuration = ConfigSpace.Configuration( - self.orig_parameter_space, values=config.to_dict()) + self.orig_parameter_space, values=config.to_dict() + ) target_config = self._suggested_configs.get(configuration, None) # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection @@ -100,16 +112,22 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # respective high-dim point; this way we can retrieve the low-dim point, from its high-dim counterpart. if target_config is None: # Inherently it is not supported to register points, which were not suggested by the optimizer. - if configuration == self.orig_parameter_space.get_default_configuration(): + if ( + configuration + == self.orig_parameter_space.get_default_configuration() + ): # Default configuration should always be registerable. pass elif not self._use_approximate_reverse_mapping: - raise ValueError(f"{repr(configuration)}\n" "The above configuration was not suggested by the optimizer. " - "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " - "previously by the optimizer can be registered.") + raise ValueError( + f"{repr(configuration)}\n" + "The above configuration was not suggested by the optimizer. " + "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " + "previously by the optimizer can be registered." + ) # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix. - if getattr(self, '_pinv_matrix', None) is None: + if getattr(self, "_pinv_matrix", None) is None: self._try_generate_approx_inverse_mapping() # Replace NaNs with zeros for inactive hyperparameters @@ -118,29 +136,43 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # NOTE: applying special value biasing is not possible vector = self._config_scaler.inverse_transform([config_vector])[0] target_config_vector = self._pinv_matrix.dot(vector) - target_config = ConfigSpace.Configuration(self.target_parameter_space, vector=target_config_vector) + target_config = ConfigSpace.Configuration( + self.target_parameter_space, vector=target_config_vector + ) target_configurations.append(target_config) - return pd.DataFrame(target_configurations, columns=list(self.target_parameter_space.keys())) + return pd.DataFrame( + target_configurations, columns=list(self.target_parameter_space.keys()) + ) def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: if len(configuration) != 1: - raise ValueError("Configuration dataframe must contain exactly 1 row. 
" - f"Found {len(configuration)} rows.") + raise ValueError( + "Configuration dataframe must contain exactly 1 row. " + f"Found {len(configuration)} rows." + ) target_values_dict = configuration.iloc[0].to_dict() - target_configuration = ConfigSpace.Configuration(self.target_parameter_space, values=target_values_dict) + target_configuration = ConfigSpace.Configuration( + self.target_parameter_space, values=target_values_dict + ) orig_values_dict = self._transform(target_values_dict) - orig_configuration = normalize_config(self.orig_parameter_space, orig_values_dict) + orig_configuration = normalize_config( + self.orig_parameter_space, orig_values_dict + ) # Add to inverse dictionary -- needed for registering the performance later self._suggested_configs[orig_configuration] = target_configuration - return pd.DataFrame([list(orig_configuration.values())], columns=list(orig_configuration.keys())) + return pd.DataFrame( + [list(orig_configuration.values())], columns=list(orig_configuration.keys()) + ) - def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_param: Optional[int]) -> None: + def _construct_low_dim_space( + self, num_low_dims: int, max_unique_values_per_param: Optional[int] + ) -> None: """Constructs the low-dimensional parameter (potentially discretized) search space. Parameters @@ -156,7 +188,9 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para q_scaler = None if max_unique_values_per_param is None: hyperparameters = [ - ConfigSpace.UniformFloatHyperparameter(name=f'dim_{idx}', lower=-1, upper=1) + ConfigSpace.UniformFloatHyperparameter( + name=f"dim_{idx}", lower=-1, upper=1 + ) for idx in range(num_low_dims) ] else: @@ -164,7 +198,9 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters. # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection. 
hyperparameters = [ - ConfigSpace.UniformIntegerHyperparameter(name=f'dim_{idx}', lower=1, upper=max_unique_values_per_param) + ConfigSpace.UniformIntegerHyperparameter( + name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param + ) for idx in range(num_low_dims) ] @@ -177,8 +213,12 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para self._q_scaler = q_scaler # Construct low-dimensional parameter search space - config_space = ConfigSpace.ConfigurationSpace(name=self.orig_parameter_space.name) - config_space.random = self._random_state # use same random state as in original parameter space + config_space = ConfigSpace.ConfigurationSpace( + name=self.orig_parameter_space.name + ) + config_space.random = ( + self._random_state + ) # use same random state as in original parameter space config_space.add_hyperparameters(hyperparameters) self._target_config_space = config_space @@ -209,17 +249,21 @@ def _transform(self, configuration: dict) -> dict: for idx in range(len(original_parameters)) ] # Scale parameter values to [0, 1] - original_config_values = self._config_scaler.transform([original_config_values])[0] + original_config_values = self._config_scaler.transform( + [original_config_values] + )[0] original_config = {} for param, norm_value in zip(original_parameters, original_config_values): # Clip value to force it to fall in [0, 1] # NOTE: HeSBO projection ensures that theoretically but due to # floating point ops nuances this is not always guaranteed - value = max(0., min(1., norm_value)) # pylint: disable=redefined-loop-name + value = max( + 0.0, min(1.0, norm_value) + ) # pylint: disable=redefined-loop-name if isinstance(param, ConfigSpace.CategoricalHyperparameter): - index = int(value * len(param.choices)) # truncate integer part + index = int(value * len(param.choices)) # truncate integer part index = max(0, min(len(param.choices) - 1, index)) # NOTE: potential rounding here would be unfair to first & last values orig_value = param.choices[index] @@ -227,16 +271,20 @@ def _transform(self, configuration: dict) -> dict: if param.name in self._special_param_values_dict: value = self._special_param_value_scaler(param, value) - orig_value = param._transform(value) # pylint: disable=protected-access + orig_value = param._transform(value) # pylint: disable=protected-access orig_value = max(param.lower, min(param.upper, orig_value)) else: - raise NotImplementedError("Only Categorical, Integer, and Float hyperparameters are currently supported.") + raise NotImplementedError( + "Only Categorical, Integer, and Float hyperparameters are currently supported." + ) original_config[param.name] = orig_value return original_config - def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float) -> float: + def _special_param_value_scaler( + self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float + ) -> float: """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. Parameters @@ -255,17 +303,20 @@ def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperpara special_values_list = self._special_param_values_dict[param.name] # Check if input value corresponds to some special value - perc_sum = 0. 
+ perc_sum = 0.0 ret: float for special_value, biasing_perc in special_values_list: perc_sum += biasing_perc if input_value < perc_sum: - ret = param._inverse_transform(special_value) # pylint: disable=protected-access + ret = param._inverse_transform( + special_value + ) # pylint: disable=protected-access return ret # Scale input value uniformly to non-special values - ret = param._inverse_transform( # pylint: disable=protected-access - param._transform_scalar((input_value - perc_sum) / (1 - perc_sum))) # pylint: disable=protected-access + ret = param._inverse_transform( # pylint: disable=protected-access + param._transform_scalar((input_value - perc_sum) / (1 - perc_sum)) + ) # pylint: disable=protected-access return ret # pylint: disable=too-complex,too-many-branches @@ -294,46 +345,79 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non hyperparameter = self.orig_parameter_space[param] if not isinstance(hyperparameter, ConfigSpace.UniformIntegerHyperparameter): - raise NotImplementedError(error_prefix + f"Parameter '{param}' is not supported. " - "Only Integer Hyperparameters are currently supported.") + raise NotImplementedError( + error_prefix + f"Parameter '{param}' is not supported. " + "Only Integer Hyperparameters are currently supported." + ) if isinstance(value, int): # User specifies a single special value -- default biasing percentage is used - tuple_list = [(value, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE)] + tuple_list = [ + (value, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) + ] elif isinstance(value, tuple) and [type(v) for v in value] == [int, float]: # User specifies both special value and biasing percentage tuple_list = [value] elif isinstance(value, list) and value: if all(isinstance(t, int) for t in value): # User specifies list of special values - tuple_list = [(v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value] - elif all(isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value): + tuple_list = [ + (v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) + for v in value + ] + elif all( + isinstance(t, tuple) and [type(v) for v in t] == [int, float] + for t in value + ): # User specifies list of tuples; each tuple defines the special value and the biasing percentage tuple_list = value else: - raise ValueError(error_prefix + f"Invalid format in value list for parameter '{param}'. " - f"Special value list should contain either integers, or (special value, biasing %) tuples.") + raise ValueError( + error_prefix + + f"Invalid format in value list for parameter '{param}'. " + f"Special value list should contain either integers, or (special value, biasing %) tuples." + ) else: - raise ValueError(error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " - "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples.") + raise ValueError( + error_prefix + + f"Invalid format for parameter '{param}'. Dict value should be " + "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples." + ) # Are user-specified special values valid? 
- if not all(hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list): - raise ValueError(error_prefix + f"One (or more) special values are outside of parameter '{param}' value domain.") + if not all( + hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list + ): + raise ValueError( + error_prefix + + f"One (or more) special values are outside of parameter '{param}' value domain." + ) # Are user-provided special values unique? if len(set(v for v, _ in tuple_list)) != len(tuple_list): - raise ValueError(error_prefix + f"One (or more) special values are defined more than once for parameter '{param}'.") + raise ValueError( + error_prefix + + f"One (or more) special values are defined more than once for parameter '{param}'." + ) # Are biasing percentages valid? if not all(0 < perc < 1 for _, perc in tuple_list): - raise ValueError(error_prefix + f"One (or more) biasing percentages for parameter '{param}' are invalid: " - "i.e., fall outside (0, 1) range.") + raise ValueError( + error_prefix + + f"One (or more) biasing percentages for parameter '{param}' are invalid: " + "i.e., fall outside (0, 1) range." + ) total_percentage = sum(perc for _, perc in tuple_list) - if total_percentage >= 1.: - raise ValueError(error_prefix + f"Total special values percentage for parameter '{param}' surpass 100%.") + if total_percentage >= 1.0: + raise ValueError( + error_prefix + + f"Total special values percentage for parameter '{param}' surpass 100%." + ) # ... and reasonable? if total_percentage >= 0.5: - warn(f"Total special values percentage for parameter '{param}' exceeds 50%.", UserWarning) + warn( + f"Total special values percentage for parameter '{param}' exceeds 50%.", + UserWarning, + ) sanitized_dict[param] = tuple_list @@ -355,9 +439,12 @@ def _try_generate_approx_inverse_mapping(self) -> None: pinv, ) - warn("Trying to register a configuration that was not previously suggested by the optimizer. " + - "This inverse configuration transformation is typically not supported. " + - "However, we will try to register this configuration using an *experimental* method.", UserWarning) + warn( + "Trying to register a configuration that was not previously suggested by the optimizer. " + + "This inverse configuration transformation is typically not supported. 
" + + "However, we will try to register this configuration using an *experimental* method.", + UserWarning, + ) orig_space_num_dims = len(list(self.orig_parameter_space.values())) target_space_num_dims = len(list(self.target_parameter_space.values())) @@ -371,5 +458,7 @@ def _try_generate_approx_inverse_mapping(self) -> None: try: self._pinv_matrix = pinv(proj_matrix) except LinAlgError as err: - raise RuntimeError(f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}") from err + raise RuntimeError( + f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}" + ) from err assert self._pinv_matrix.shape == (target_space_num_dims, orig_space_num_dims) diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index d6918f9891..4fec0ed242 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -27,7 +27,9 @@ FlamlSpace: TypeAlias = Dict[str, flaml.tune.sample.Domain] -def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> Dict[str, FlamlDomain]: +def configspace_to_flaml_space( + config_space: ConfigSpace.ConfigurationSpace, +) -> Dict[str, FlamlDomain]: """Converts a ConfigSpace.ConfigurationSpace to dict. Parameters @@ -50,13 +52,23 @@ def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain: if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter): # FIXME: upper isn't included in the range - return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper) + return flaml_numeric_type[(type(parameter), parameter.log)]( + parameter.lower, parameter.upper + ) elif isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter): - return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper + 1) + return flaml_numeric_type[(type(parameter), parameter.log)]( + parameter.lower, parameter.upper + 1 + ) elif isinstance(parameter, ConfigSpace.CategoricalHyperparameter): if len(np.unique(parameter.probabilities)) > 1: - raise ValueError("FLAML doesn't support categorical parameters with non-uniform probabilities.") - return flaml.tune.choice(parameter.choices) # TODO: set order? - raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.") + raise ValueError( + "FLAML doesn't support categorical parameters with non-uniform probabilities." + ) + return flaml.tune.choice(parameter.choices) # TODO: set order? + raise ValueError( + f"Type of parameter {parameter} ({type(parameter)}) not supported." + ) - return {param.name: _one_parameter_convert(param) for param in config_space.values()} + return { + param.name: _one_parameter_convert(param) for param in config_space.values() + } diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index a8ad146205..99dcbf2b2f 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -21,7 +21,7 @@ from typing_extensions import TypeAlias -T = TypeVar('T') +T = TypeVar("T") def get_all_submodules(pkg: TypeAlias) -> List[str]: @@ -30,7 +30,9 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]: Useful for dynamically enumerating subclasses. 
""" submodules = [] - for _, submodule_name, _ in walk_packages(pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None): + for _, submodule_name, _ in walk_packages( + pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None + ): submodules.append(submodule_name) return submodules @@ -41,10 +43,13 @@ def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]: Useful for dynamically enumerating expected test cases. """ return set(cls.__subclasses__()).union( - s for c in cls.__subclasses__() for s in _get_all_subclasses(c)) + s for c in cls.__subclasses__() for s in _get_all_subclasses(c) + ) -def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]: +def get_all_concrete_subclasses( + cls: Type[T], pkg_name: Optional[str] = None +) -> List[Type[T]]: """ Gets a sorted list of all of the concrete subclasses of the given class. Useful for dynamically enumerating expected test cases. @@ -57,5 +62,11 @@ def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> pkg = import_module(pkg_name) submodules = get_all_submodules(pkg) assert submodules - return sorted([subclass for subclass in _get_all_subclasses(cls) if not getattr(subclass, "__abstractmethods__", None)], - key=lambda c: (c.__module__, c.__name__)) + return sorted( + [ + subclass + for subclass in _get_all_subclasses(cls) + if not getattr(subclass, "__abstractmethods__", None) + ], + key=lambda c: (c.__module__, c.__name__), + ) diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index c7a94dfcc4..775afa2455 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -17,24 +17,27 @@ @pytest.mark.filterwarnings("error:Not Implemented") -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_context_not_implemented_warning(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_context_not_implemented_warning( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """ Make sure we raise warnings for the functionality that has not been implemented yet. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) suggestion, _metadata = optimizer.suggest() - scores = pd.DataFrame({'score': [1]}) + scores = pd.DataFrame({"score": [1]}) context = pd.DataFrame([["something"]]) with pytest.raises(UserWarning): diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py index 39231bec5c..5efdbb81cf 100644 --- a/mlos_core/mlos_core/tests/optimizers/conftest.py +++ b/mlos_core/mlos_core/tests/optimizers/conftest.py @@ -18,9 +18,13 @@ def configuration_space() -> CS.ConfigurationSpace: # Start defining a ConfigurationSpace for the Optimizer to search. space = CS.ConfigurationSpace(seed=1234) # Add a continuous input dimension between 0 and 1. 
- space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) + space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) # Add a categorical hyperparameter with 3 possible values. - space.add_hyperparameter(CS.CategoricalHyperparameter(name='y', choices=["a", "b", "c"])) + space.add_hyperparameter( + CS.CategoricalHyperparameter(name="y", choices=["a", "b", "c"]) + ) # Add a discrete input dimension between 0 and 10. - space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='z', lower=0, upper=10)) + space.add_hyperparameter( + CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10) + ) return space diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index 725d92fbe9..be2e89137d 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -23,11 +23,13 @@ def data_frame() -> pd.DataFrame: Toy data frame corresponding to the `configuration_space` hyperparameters. The columns are deliberately *not* in alphabetic order. """ - return pd.DataFrame({ - 'y': ['a', 'b', 'c'], - 'x': [0.1, 0.2, 0.3], - 'z': [1, 5, 8], - }) + return pd.DataFrame( + { + "y": ["a", "b", "c"], + "x": [0.1, 0.2, 0.3], + "z": [1, 5, 8], + } + ) @pytest.fixture @@ -36,11 +38,13 @@ def one_hot_data_frame() -> npt.NDArray: One-hot encoding of the `data_frame` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array([ - [0.1, 1.0, 0.0, 0.0, 1.0], - [0.2, 0.0, 1.0, 0.0, 5.0], - [0.3, 0.0, 0.0, 1.0, 8.0], - ]) + return np.array( + [ + [0.1, 1.0, 0.0, 0.0, 1.0], + [0.2, 0.0, 1.0, 0.0, 5.0], + [0.3, 0.0, 0.0, 1.0, 8.0], + ] + ) @pytest.fixture @@ -49,11 +53,13 @@ def series() -> pd.Series: Toy series corresponding to the `configuration_space` hyperparameters. The columns are deliberately *not* in alphabetic order. """ - return pd.Series({ - 'y': 'b', - 'x': 0.4, - 'z': 3, - }) + return pd.Series( + { + "y": "b", + "x": 0.4, + "z": 3, + } + ) @pytest.fixture @@ -62,9 +68,11 @@ def one_hot_series() -> npt.NDArray: One-hot encoding of the `series` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array([ - [0.4, 0.0, 1.0, 0.0, 3], - ]) + return np.array( + [ + [0.4, 0.0, 1.0, 0.0, 3], + ] + ) @pytest.fixture @@ -74,48 +82,56 @@ def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: """ return SmacOptimizer( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], ) -def test_to_1hot_data_frame(optimizer: BaseOptimizer, - data_frame: pd.DataFrame, - one_hot_data_frame: npt.NDArray) -> None: +def test_to_1hot_data_frame( + optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray +) -> None: """ Toy problem to test one-hot encoding of dataframe. """ assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame) -def test_to_1hot_series(optimizer: BaseOptimizer, - series: pd.Series, one_hot_series: npt.NDArray) -> None: +def test_to_1hot_series( + optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray +) -> None: """ Toy problem to test one-hot encoding of series. 
""" assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series) -def test_from_1hot_data_frame(optimizer: BaseOptimizer, - data_frame: pd.DataFrame, - one_hot_data_frame: npt.NDArray) -> None: +def test_from_1hot_data_frame( + optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray +) -> None: """ Toy problem to test one-hot decoding of dataframe. """ - assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict() + assert ( + optimizer._from_1hot(config=one_hot_data_frame).to_dict() + == data_frame.to_dict() + ) -def test_from_1hot_series(optimizer: BaseOptimizer, - series: pd.Series, - one_hot_series: npt.NDArray) -> None: +def test_from_1hot_series( + optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray +) -> None: """ Toy problem to test one-hot decoding of series. """ one_hot_df = optimizer._from_1hot(config=one_hot_series) - assert one_hot_df.shape[0] == 1, f"Unexpected number of rows ({one_hot_df.shape[0]} != 1)" + assert ( + one_hot_df.shape[0] == 1 + ), f"Unexpected number of rows ({one_hot_df.shape[0]} != 1)" assert one_hot_df.iloc[0].to_dict() == series.to_dict() -def test_round_trip_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFrame) -> None: +def test_round_trip_data_frame( + optimizer: BaseOptimizer, data_frame: pd.DataFrame +) -> None: """ Round-trip test for one-hot-encoding and then decoding a data frame. """ @@ -135,17 +151,21 @@ def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> No assert (series_round_trip.z == series.z).all() -def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer, - one_hot_data_frame: npt.NDArray) -> None: +def test_round_trip_reverse_data_frame( + optimizer: BaseOptimizer, one_hot_data_frame: npt.NDArray +) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ - round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_data_frame)) + round_trip = optimizer._to_1hot( + config=optimizer._from_1hot(config=one_hot_data_frame) + ) assert round_trip == pytest.approx(one_hot_data_frame) -def test_round_trip_reverse_series(optimizer: BaseOptimizer, - one_hot_series: npt.NDArray) -> None: +def test_round_trip_reverse_series( + optimizer: BaseOptimizer, one_hot_series: npt.NDArray +) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index 0b9d624a7a..ad9ae51d23 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -20,10 +20,15 @@ _LOG = logging.getLogger(__name__) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_multi_target_opt_wrong_weights( + optimizer_class: Type[BaseOptimizer], kwargs: dict +) -> None: """ Make sure that the optimizer raises an error if the number of objective weights does not match the number of optimization targets. 
@@ -31,23 +36,31 @@ def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kw with pytest.raises(ValueError): optimizer_class( parameter_space=CS.ConfigurationSpace(seed=SEED), - optimization_targets=['main_score', 'other_score'], + optimization_targets=["main_score", "other_score"], objective_weights=[1], - **kwargs + **kwargs, ) -@pytest.mark.parametrize(('objective_weights'), [ - [2, 1], - [0.5, 0.5], - None, -]) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_multi_target_opt(objective_weights: Optional[List[float]], - optimizer_class: Type[BaseOptimizer], - kwargs: dict) -> None: +@pytest.mark.parametrize( + ("objective_weights"), + [ + [2, 1], + [0.5, 0.5], + None, + ], +) +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_multi_target_opt( + objective_weights: Optional[List[float]], + optimizer_class: Type[BaseOptimizer], + kwargs: dict, +) -> None: """ Toy multi-target optimization problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. @@ -56,21 +69,25 @@ def test_multi_target_opt(objective_weights: Optional[List[float]], def objective(point: pd.DataFrame) -> pd.DataFrame: # mix of hyperparameters, optimal is to select the highest possible - return pd.DataFrame({ - "main_score": point.x + point.y, - "other_score": point.x ** 2 + point.y ** 2, - }) + return pd.DataFrame( + { + "main_score": point.x + point.y, + "other_score": point.x**2 + point.y**2, + } + ) input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) + CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5) + ) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) + CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0) + ) optimizer = optimizer_class( parameter_space=input_space, - optimization_targets=['main_score', 'other_score'], + optimization_targets=["main_score", "other_score"], objective_weights=objective_weights, **kwargs, ) @@ -85,27 +102,28 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {'x', 'y'} + assert set(suggestion.columns) == {"x", "y"} # Check suggestion values are the expected dtype assert isinstance(suggestion.x.iloc[0], np.integer) assert isinstance(suggestion.y.iloc[0], np.floating) # Check that suggestion is in the space test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) + optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() + ) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
observation = objective(suggestion) assert isinstance(observation, pd.DataFrame) - assert set(observation.columns) == {'main_score', 'other_score'} + assert set(observation.columns) == {"main_score", "other_score"} optimizer.register(configs=suggestion, scores=observation) (best_config, best_score, best_context) = optimizer.get_best_observations() assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y'} - assert set(best_score.columns) == {'main_score', 'other_score'} + assert set(best_config.columns) == {"x", "y"} + assert set(best_score.columns) == {"main_score", "other_score"} assert best_config.shape == (1, 2) assert best_score.shape == (1, 2) @@ -113,7 +131,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y'} - assert set(all_scores.columns) == {'main_score', 'other_score'} + assert set(all_configs.columns) == {"x", "y"} + assert set(all_scores.columns) == {"main_score", "other_score"} assert all_configs.shape == (max_iterations, 2) assert all_scores.shape == (max_iterations, 2) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index 5fd28ca1ed..c923a4f4bc 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -32,20 +32,24 @@ _LOG.setLevel(logging.DEBUG) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_create_optimizer_and_suggest( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """ Test that we can create an optimizer and get a suggestion from it. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) assert optimizer is not None @@ -62,11 +66,17 @@ def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace optimizer.register_pending(configs=suggestion, metadata=metadata) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_basic_interface_toy_problem( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """ Toy problem to test the optimizers. """ @@ -77,17 +87,15 @@ def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, if optimizer_class == OptimizerType.SMAC.value: # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100. 
# To avoid having to train more than 25 model iterations, we set a lower number of max iterations. - kwargs['max_trials'] = max_iterations * 2 + kwargs["max_trials"] = max_iterations * 2 def objective(x: pd.Series) -> pd.DataFrame: - return pd.DataFrame({"score": (6 * x - 2)**2 * np.sin(12 * x - 4)}) + return pd.DataFrame({"score": (6 * x - 2) ** 2 * np.sin(12 * x - 4)}) # Emukit doesn't allow specifying a random state, so we set the global seed. np.random.seed(SEED) optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) with pytest.raises(ValueError, match="No observations"): @@ -100,12 +108,14 @@ def objective(x: pd.Series) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {'x', 'y', 'z'} + assert set(suggestion.columns) == {"x", "y", "z"} # check that suggestion is in the space - configuration = CS.Configuration(optimizer.parameter_space, suggestion.iloc[0].to_dict()) + configuration = CS.Configuration( + optimizer.parameter_space, suggestion.iloc[0].to_dict() + ) # Raises an error if outside of configuration space configuration.is_valid_configuration() - observation = objective(suggestion['x']) + observation = objective(suggestion["x"]) assert isinstance(observation, pd.DataFrame) optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -113,8 +123,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y', 'z'} - assert set(best_score.columns) == {'score'} + assert set(best_config.columns) == {"x", "y", "z"} + assert set(best_score.columns) == {"score"} assert best_config.shape == (1, 3) assert best_score.shape == (1, 1) assert best_score.score.iloc[0] < -5 @@ -123,8 +133,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y', 'z'} - assert set(all_scores.columns) == {'score'} + assert set(all_configs.columns) == {"x", "y", "z"} + assert set(all_scores.columns) == {"score"} assert all_configs.shape == (20, 3) assert all_scores.shape == (20, 1) @@ -137,27 +147,36 @@ def objective(x: pd.Series) -> pd.DataFrame: assert pred_all.shape == (20,) -@pytest.mark.parametrize(('optimizer_type'), [ - # Enumerate all supported Optimizers - # *[member for member in OptimizerType], - *list(OptimizerType), -]) +@pytest.mark.parametrize( + ("optimizer_type"), + [ + # Enumerate all supported Optimizers + # *[member for member in OptimizerType], + *list(OptimizerType), + ], +) def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: """ Test that all optimizer types are listed in the ConcreteOptimizer constraints. 
""" - assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member - - -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument -]) -def test_create_optimizer_with_factory_method(configuration_space: CS.ConfigurationSpace, - optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: + assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member + + +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ], +) +def test_create_optimizer_with_factory_method( + configuration_space: CS.ConfigurationSpace, + optimizer_type: Optional[OptimizerType], + kwargs: Optional[dict], +) -> None: """ Test that we can create an optimizer via a factory. """ @@ -166,13 +185,13 @@ def test_create_optimizer_with_factory_method(configuration_space: CS.Configurat if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -188,17 +207,25 @@ def test_create_optimizer_with_factory_method(configuration_space: CS.Configurat assert myrepr.startswith(optimizer_type.value.__name__) -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - (OptimizerType.SMAC, { - # Test with default config. - 'use_default_config': True, - # 'n_random_init': 10, - }), -]) -def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ( + OptimizerType.SMAC, + { + # Test with default config. + "use_default_config": True, + # 'n_random_init': 10, + }, + ), + ], +) +def test_optimizer_with_llamatune( + optimizer_type: OptimizerType, kwargs: Optional[dict] +) -> None: """ Toy problem to test the optimizers with llamatune space adapter. 
""" @@ -215,8 +242,12 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=1234) # Add two continuous inputs - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=3)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=3)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name="x", lower=0, upper=3) + ) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name="y", lower=0, upper=3) + ) # Initialize an optimizer that uses LlamaTune space adapter space_adapter_kwargs = { @@ -239,7 +270,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: llamatune_optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=llamatune_optimizer_kwargs, space_adapter_type=SpaceAdapterType.LLAMATUNE, @@ -248,16 +279,19 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Initialize an optimizer that uses the original space optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=optimizer_kwargs, ) assert optimizer is not None assert llamatune_optimizer is not None - assert optimizer.optimizer_parameter_space != llamatune_optimizer.optimizer_parameter_space + assert ( + optimizer.optimizer_parameter_space + != llamatune_optimizer.optimizer_parameter_space + ) llamatune_n_random_init = 0 - opt_n_random_init = int(kwargs.get('n_random_init', 0)) + opt_n_random_init = int(kwargs.get("n_random_init", 0)) if optimizer_type == OptimizerType.SMAC: assert isinstance(optimizer, SmacOptimizer) assert isinstance(llamatune_optimizer, SmacOptimizer) @@ -278,37 +312,48 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # loop for llamatune-optimizer suggestion, metadata = llamatune_optimizer.suggest() - _x, _y = suggestion['x'].iloc[0], suggestion['y'].iloc[0] - assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3., rel=1e-3) # optimizer explores 1-dimensional space + _x, _y = suggestion["x"].iloc[0], suggestion["y"].iloc[0] + assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx( + 3.0, rel=1e-3 + ) # optimizer explores 1-dimensional space observation = objective(suggestion) - llamatune_optimizer.register(configs=suggestion, scores=observation, metadata=metadata) + llamatune_optimizer.register( + configs=suggestion, scores=observation, metadata=metadata + ) # Retrieve best observations best_observation = optimizer.get_best_observations() llamatune_best_observation = llamatune_optimizer.get_best_observations() - for (best_config, best_score, best_context) in (best_observation, llamatune_best_observation): + for best_config, best_score, best_context in ( + best_observation, + llamatune_best_observation, + ): assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y'} - assert set(best_score.columns) == {'score'} + assert set(best_config.columns) == {"x", "y"} + assert set(best_score.columns) == {"score"} (best_config, best_score, _context) = best_observation (llamatune_best_config, llamatune_best_score, _context) = llamatune_best_observation # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that - assert 
best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] or \ - best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] + assert ( + best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] + or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] + ) # Retrieve and check all observations - for (all_configs, all_scores, all_contexts) in ( - optimizer.get_observations(), llamatune_optimizer.get_observations()): + for all_configs, all_scores, all_contexts in ( + optimizer.get_observations(), + llamatune_optimizer.get_observations(), + ): assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y'} - assert set(all_scores.columns) == {'score'} + assert set(all_configs.columns) == {"x", "y"} + assert set(all_scores.columns) == {"score"} assert len(all_configs) == num_iters assert len(all_scores) == num_iters @@ -320,12 +365,13 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses(BaseOptimizer, # type: ignore[type-abstract] - pkg_name='mlos_core') +optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses( + BaseOptimizer, pkg_name="mlos_core" # type: ignore[type-abstract] +) assert optimizer_subclasses -@pytest.mark.parametrize(('optimizer_class'), optimizer_subclasses) +@pytest.mark.parametrize(("optimizer_class"), optimizer_subclasses) def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: """ Test that all optimizer classes are listed in the OptimizerType enum. @@ -334,14 +380,19 @@ def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: assert optimizer_class in optimizer_type_classes -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument -]) -def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ], +) +def test_mixed_numerics_type_input_space_types( + optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] +) -> None: """ Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
""" @@ -355,19 +406,23 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5) + ) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0) + ) if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -381,12 +436,14 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: for _ in range(max_iterations): suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) - assert (suggestion.columns == ['x', 'y']).all() + assert (suggestion.columns == ["x", "y"]).all() # Check suggestion values are the expected dtype - assert isinstance(suggestion['x'].iloc[0], np.integer) - assert isinstance(suggestion['y'].iloc[0], np.floating) + assert isinstance(suggestion["x"].iloc[0], np.integer) + assert isinstance(suggestion["y"].iloc[0], np.floating) # Check that suggestion is in the space - test_configuration = CS.Configuration(optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) + test_configuration = CS.Configuration( + optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() + ) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py index 37b8aa3a69..13a28d242d 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py @@ -20,22 +20,33 @@ def test_identity_adapter() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='float_1', lower=0, upper=100)) + CS.UniformFloatHyperparameter(name="float_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) + CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) + ) adapter = IdentityAdapter(orig_parameter_space=input_space) num_configs = 10 - for sampled_config in input_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # (false positive) - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + for sampled_config in input_space.sample_configuration( + size=num_configs + ): # pylint: disable=not-an-iterable # (false positive) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) assert target_config_df.equals(sampled_config_df) - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == sampled_config orig_config_df = adapter.transform(target_config_df) assert orig_config_df.equals(sampled_config_df) - orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() + ) assert orig_config == sampled_config diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index 84dcd4e5c0..d0dfcb7691 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -30,34 +30,64 @@ def construct_parameter_space( for idx in range(n_continuous_params): input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name=f'cont_{idx}', lower=0, upper=64)) + CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64) + ) for idx in range(n_integer_params): input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name=f'int_{idx}', lower=-1, upper=256)) + CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=-1, upper=256) + ) for idx in range(n_categorical_params): input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name=f'str_{idx}', choices=[f'option_{idx}' for idx in range(5)])) + CS.CategoricalHyperparameter( + name=f"str_{idx}", choices=[f"option_{idx}" for idx in range(5)] + ) + ) return input_space -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( 
- {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + { + "n_continuous_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + { + "n_integer_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + { + "n_categorical_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + # Mix of all three types + { + "n_continuous_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + "n_integer_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + "n_categorical_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + }, + ) + ] + ), +) +def test_num_low_dims( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's low-to-high space projection method. """ @@ -66,8 +96,7 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N # Number of target parameter space dimensions should be fewer than those of the original space with pytest.raises(ValueError): LlamaTuneAdapter( - orig_parameter_space=input_space, - num_low_dims=len(list(input_space.keys())) + orig_parameter_space=input_space, num_low_dims=len(list(input_space.keys())) ) # Enable only low-dimensional space projections @@ -75,35 +104,53 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N orig_parameter_space=input_space, num_low_dims=num_target_space_dims, special_param_values=None, - max_unique_values_per_param=None + max_unique_values_per_param=None, ) sampled_configs = adapter.target_parameter_space.sample_configuration(size=100) - for sampled_config in sampled_configs: # pylint: disable=not-an-iterable # (false positive) + for ( + sampled_config + ) in sampled_configs: # pylint: disable=not-an-iterable # (false positive) # Transform low-dim config to high-dim point/config - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) orig_config_df = adapter.transform(sampled_config_df) # High-dim (i.e., original) config should be valid - orig_config = CS.Configuration(input_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + input_space, values=orig_config_df.iloc[0].to_dict() + ) input_space.check_configuration(orig_config) # Transform high-dim config back to low-dim target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + 
adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == sampled_config # Try inverse projection (i.e., high-to-low) for previously unseen configs - unseen_sampled_configs = adapter.target_parameter_space.sample_configuration(size=25) - for unseen_sampled_config in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) - if unseen_sampled_config in sampled_configs: # pylint: disable=unsupported-membership-test # (false positive) + unseen_sampled_configs = adapter.target_parameter_space.sample_configuration( + size=25 + ) + for ( + unseen_sampled_config + ) in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) + if ( + unseen_sampled_config in sampled_configs + ): # pylint: disable=unsupported-membership-test # (false positive) continue - unseen_sampled_config_df = pd.DataFrame([unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys())) + unseen_sampled_config_df = pd.DataFrame( + [unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys()) + ) with pytest.raises(ValueError): - _ = adapter.inverse_transform(unseen_sampled_config_df) # pylint: disable=redefined-variable-type + _ = adapter.inverse_transform( + unseen_sampled_config_df + ) # pylint: disable=redefined-variable-type def test_special_parameter_values_validation() -> None: @@ -112,15 +159,20 @@ def test_special_parameter_values_validation() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str', choices=[f'choice_{idx}' for idx in range(5)])) + CS.CategoricalHyperparameter( + name="str", choices=[f"choice_{idx}" for idx in range(5)] + ) + ) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont', lower=-1, upper=100)) + CS.UniformFloatHyperparameter(name="cont", lower=-1, upper=100) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int", lower=0, upper=100) + ) # Only UniformIntegerHyperparameters are currently supported with pytest.raises(NotImplementedError): - special_param_values_dict_1 = {'str': 'choice_1'} + special_param_values_dict_1 = {"str": "choice_1"} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -129,7 +181,7 @@ def test_special_parameter_values_validation() -> None: ) with pytest.raises(NotImplementedError): - special_param_values_dict_2 = {'cont': -1} + special_param_values_dict_2 = {"cont": -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -138,8 +190,8 @@ def test_special_parameter_values_validation() -> None: ) # Special value should belong to parameter value domain - with pytest.raises(ValueError, match='value domain'): - special_param_values_dict = {'int': -1} + with pytest.raises(ValueError, match="value domain"): + special_param_values_dict = {"int": -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -149,15 +201,17 @@ def test_special_parameter_values_validation() -> None: # Invalid dicts; ValueError should be thrown invalid_special_param_values_dicts: List[Dict[str, Any]] = [ - {'int-Q': 0}, # parameter does not exist - {'int': {0: 0.2}}, # invalid definition - {'int': 0.2}, # invalid parameter value - {'int': (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) - {'int': [0, 0]}, # duplicate special values - {'int': []}, # empty list - {'int': [{0: 0.2}]}, - {'int': [(0.4, 0), (1, 0.7)]}, # first 
tuple is inverted; second is correct - {'int': [(0, 0.1), (0, 0.2)]}, # duplicate special values + {"int-Q": 0}, # parameter does not exist + {"int": {0: 0.2}}, # invalid definition + {"int": 0.2}, # invalid parameter value + { + "int": (0.4, 0) + }, # (biasing %, special value) instead of (special value, biasing %) + {"int": [0, 0]}, # duplicate special values + {"int": []}, # empty list + {"int": [{0: 0.2}]}, + {"int": [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct + {"int": [(0, 0.1), (0, 0.2)]}, # duplicate special values ] for spv_dict in invalid_special_param_values_dicts: with pytest.raises(ValueError): @@ -170,13 +224,13 @@ def test_special_parameter_values_validation() -> None: # Biasing percentage of special value(s) are invalid invalid_special_param_values_dicts = [ - {'int': (0, 1.1)}, # >1 probability - {'int': (0, 0)}, # Zero probability - {'int': (0, -0.1)}, # Negative probability - {'int': (0, 20)}, # 2,000% instead of 20% - {'int': [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% - {'int': [(0, 0.4), (1, 0.7)]}, # combined probability >100% - {'int': [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. + {"int": (0, 1.1)}, # >1 probability + {"int": (0, 0)}, # Zero probability + {"int": (0, -0.1)}, # Negative probability + {"int": (0, 20)}, # 2,000% instead of 20% + {"int": [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% + {"int": [(0, 0.4), (1, 0.7)]}, # combined probability >100% + {"int": [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. ] for spv_dict in invalid_special_param_values_dicts: @@ -189,24 +243,34 @@ def test_special_parameter_values_validation() -> None: ) -def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[CS.Configuration]: - for sampled_config in adapter.target_parameter_space.sample_configuration(size=num_configs): +def gen_random_configs( + adapter: LlamaTuneAdapter, num_configs: int +) -> Iterator[CS.Configuration]: + for sampled_config in adapter.target_parameter_space.sample_configuration( + size=num_configs + ): # Transform low-dim config to high-dim config - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) orig_config_df = adapter.transform(sampled_config_df) - orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() + ) yield orig_config -def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex +def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex """ Tests LlamaTune's special parameter values biasing methodology """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=100) + ) num_configs = 400 bias_percentage = LlamaTuneAdapter.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE @@ -214,10 +278,10 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co # Single parameter; single special value special_param_value_dicts: List[Dict[str, 
Any]] = [ - {'int_1': 0}, - {'int_1': (0, bias_percentage)}, - {'int_1': [0]}, - {'int_1': [(0, bias_percentage)]} + {"int_1": 0}, + {"int_1": (0, bias_percentage)}, + {"int_1": [0]}, + {"int_1": [(0, bias_percentage)]}, ] for spv_dict in special_param_value_dicts: @@ -229,13 +293,18 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co ) special_value_occurrences = sum( - 1 for config in gen_random_configs(adapter, num_configs) if config['int_1'] == 0) - assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences + 1 + for config in gen_random_configs(adapter, num_configs) + if config["int_1"] == 0 + ) + assert (1 - eps) * int( + num_configs * bias_percentage + ) <= special_value_occurrences # Single parameter; multiple special values special_param_value_dicts = [ - {'int_1': [0, 1]}, - {'int_1': [(0, bias_percentage), (1, bias_percentage)]} + {"int_1": [0, 1]}, + {"int_1": [(0, bias_percentage), (1, bias_percentage)]}, ] for spv_dict in special_param_value_dicts: @@ -248,18 +317,22 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co special_values_occurrences = {0: 0, 1: 0} for config in gen_random_configs(adapter, num_configs): - if config['int_1'] == 0: + if config["int_1"] == 0: special_values_occurrences[0] += 1 - elif config['int_1'] == 1: + elif config["int_1"] == 1: special_values_occurrences[1] += 1 - assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[0] - assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[1] + assert (1 - eps) * int( + num_configs * bias_percentage + ) <= special_values_occurrences[0] + assert (1 - eps) * int( + num_configs * bias_percentage + ) <= special_values_occurrences[1] # Multiple parameters; multiple special values; different biasing percentage spv_dict = { - 'int_1': [(0, bias_percentage), (1, bias_percentage / 2)], - 'int_2': [(2, bias_percentage / 2), (100, bias_percentage * 1.5)] + "int_1": [(0, bias_percentage), (1, bias_percentage / 2)], + "int_2": [(2, bias_percentage / 2), (100, bias_percentage * 1.5)], } adapter = LlamaTuneAdapter( orig_parameter_space=input_space, @@ -269,24 +342,32 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co ) special_values_instances: Dict[str, Dict[int, int]] = { - 'int_1': {0: 0, 1: 0}, - 'int_2': {2: 0, 100: 0}, + "int_1": {0: 0, 1: 0}, + "int_2": {2: 0, 100: 0}, } for config in gen_random_configs(adapter, num_configs): - if config['int_1'] == 0: - special_values_instances['int_1'][0] += 1 - elif config['int_1'] == 1: - special_values_instances['int_1'][1] += 1 - - if config['int_2'] == 2: - special_values_instances['int_2'][2] += 1 - elif config['int_2'] == 100: - special_values_instances['int_2'][100] += 1 - - assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances['int_1'][0] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_1'][1] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_2'][2] - assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances['int_2'][100] + if config["int_1"] == 0: + special_values_instances["int_1"][0] += 1 + elif config["int_1"] == 1: + special_values_instances["int_1"][1] += 1 + + if config["int_2"] == 2: + special_values_instances["int_2"][2] += 1 + elif config["int_2"] == 100: + special_values_instances["int_2"][100] += 1 + + assert (1 - eps) * int(num_configs * 
bias_percentage) <= special_values_instances[ + "int_1" + ][0] + assert (1 - eps) * int( + num_configs * bias_percentage / 2 + ) <= special_values_instances["int_1"][1] + assert (1 - eps) * int( + num_configs * bias_percentage / 2 + ) <= special_values_instances["int_2"][2] + assert (1 - eps) * int( + num_configs * bias_percentage * 1.5 + ) <= special_values_instances["int_2"][100] def test_max_unique_values_per_param() -> None: @@ -296,17 +377,25 @@ def test_max_unique_values_per_param() -> None: # Define config space with a mix of different parameter types input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont_1', lower=0, upper=5)) + CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5) + ) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont_2', lower=1, upper=100)) + CS.UniformFloatHyperparameter(name="cont_2", lower=1, upper=100) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=1, upper=10)) + CS.UniformIntegerHyperparameter(name="int_1", lower=1, upper=10) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=2048)) + CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=2048) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) + CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_2', choices=[f'choice_{idx}' for idx in range(10)])) + CS.CategoricalHyperparameter( + name="str_2", choices=[f"choice_{idx}" for idx in range(10)] + ) + ) # Restrict the number of unique parameter values num_configs = 200 @@ -319,7 +408,9 @@ def test_max_unique_values_per_param() -> None: ) # Keep track of unique values generated for each parameter - unique_values_dict: Dict[str, set] = {param: set() for param in list(input_space.keys())} + unique_values_dict: Dict[str, set] = { + param: set() for param in list(input_space.keys()) + } for config in gen_random_configs(adapter, num_configs): for param, value in config.items(): unique_values_dict[param].add(value) @@ -329,23 +420,48 @@ def test_max_unique_values_per_param() -> None: assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + { + "n_continuous_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, 
+ { + "n_integer_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + { + "n_categorical_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + # Mix of all three types + { + "n_continuous_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + "n_integer_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + "n_categorical_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + }, + ) + ] + ), +) +def test_approx_inverse_mapping( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's approximate high-to-low space projection method, using pseudo-inverse. """ @@ -360,9 +476,11 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: use_approximate_reverse_mapping=False, ) - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.raises(ValueError): - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) _ = adapter.inverse_transform(sampled_config_df) # Enable low-dimensional space projection *and* reverse mapping @@ -375,41 +493,63 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: ) # Warning should be printed the first time - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.warns(UserWarning): - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) adapter.target_parameter_space.check_configuration(target_config) # Test inverse transform with 100 random configs for _ in range(100): - sampled_config = input_space.sample_configuration() # size=1) - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config = input_space.sample_configuration() # size=1) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) adapter.target_parameter_space.check_configuration(target_config) -@pytest.mark.parametrize(('num_low_dims', 'special_param_values', 'max_unique_values_per_param'), ([ - (num_low_dims, special_param_values, max_unique_values_per_param) - for num_low_dims in (8, 16) - for special_param_values in ( - {'int_1': -1, 'int_2': -1, 'int_3': -1, 'int_4': [-1, 0]}, - {'int_1': (-1, 0.1), 'int_2': -1, 'int_3': (-1, 0.3), 'int_4': [(-1, 0.1), (0, 0.2)]}, - ) - for max_unique_values_per_param in (50, 250) -])) -def test_llamatune_pipeline(num_low_dims: int, 
special_param_values: dict, max_unique_values_per_param: int) -> None: +@pytest.mark.parametrize( + ("num_low_dims", "special_param_values", "max_unique_values_per_param"), + ( + [ + (num_low_dims, special_param_values, max_unique_values_per_param) + for num_low_dims in (8, 16) + for special_param_values in ( + {"int_1": -1, "int_2": -1, "int_3": -1, "int_4": [-1, 0]}, + { + "int_1": (-1, 0.1), + "int_2": -1, + "int_3": (-1, 0.3), + "int_4": [(-1, 0.1), (0, 0.2)], + }, + ) + for max_unique_values_per_param in (50, 250) + ] + ), +) +def test_llamatune_pipeline( + num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int +) -> None: """ Tests LlamaTune space adapter when all components are active. """ # pylint: disable=too-many-locals # Define config space with a mix of different parameter types - input_space = construct_parameter_space(n_continuous_params=10, n_integer_params=10, n_categorical_params=5) + input_space = construct_parameter_space( + n_continuous_params=10, n_integer_params=10, n_categorical_params=5 + ) adapter = LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=num_low_dims, @@ -419,23 +559,29 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u special_value_occurrences = { param: {special_value: 0 for special_value, _ in tuples_list} - for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access + for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access } unique_values_dict: Dict[str, Set] = {param: set() for param in input_space.keys()} num_configs = 1000 - for config in adapter.target_parameter_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable + for config in adapter.target_parameter_space.sample_configuration( + size=num_configs + ): # pylint: disable=not-an-iterable # Transform low-dim config to high-dim point/config sampled_config_df = pd.DataFrame([config.values()], columns=list(config.keys())) orig_config_df = adapter.transform(sampled_config_df) # High-dim (i.e., original) config should be valid - orig_config = CS.Configuration(input_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + input_space, values=orig_config_df.iloc[0].to_dict() + ) input_space.check_configuration(orig_config) # Transform high-dim config back to low-dim target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == config for param, value in orig_config.items(): @@ -449,35 +595,66 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u # Ensure that occurrences of special values do not significantly deviate from expected eps = 0.2 - for param, tuples_list in adapter._special_param_values_dict.items(): # pylint: disable=protected-access + for ( + param, + tuples_list, + ) in adapter._special_param_values_dict.items(): # pylint: disable=protected-access for value, bias_percentage in tuples_list: - assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[param][value] + assert (1 - eps) * int( + num_configs * bias_percentage + ) <= special_value_occurrences[param][value] # Ensure that number of unique values is less than the maximum 
number allowed for _, unique_values in unique_values_dict.items(): assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_space_kwargs: dict) -> None: +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + { + "n_continuous_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + { + "n_integer_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + { + "n_categorical_params": int( + num_target_space_dims * num_orig_space_factor + ) + }, + # Mix of all three types + { + "n_continuous_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + "n_integer_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + "n_categorical_params": int( + num_target_space_dims * num_orig_space_factor / 3 + ), + }, + ) + ] + ), +) +def test_deterministic_behavior_for_same_seed( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: """ Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. 
""" + def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: input_space = construct_parameter_space(**param_space_kwargs, seed=seed) @@ -490,8 +667,14 @@ def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: use_approximate_reverse_mapping=False, ) - sample_configs: List[CS.Configuration] = adapter.target_parameter_space.sample_configuration(size=100) + sample_configs: List[CS.Configuration] = ( + adapter.target_parameter_space.sample_configuration(size=100) + ) return sample_configs - assert generate_target_param_space_configs(42) == generate_target_param_space_configs(42) - assert generate_target_param_space_configs(1234) != generate_target_param_space_configs(42) + assert generate_target_param_space_configs( + 42 + ) == generate_target_param_space_configs(42) + assert generate_target_param_space_configs( + 1234 + ) != generate_target_param_space_configs(42) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index 5390f97c5f..c2edd18b69 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -23,39 +23,51 @@ from mlos_core.tests import get_all_concrete_subclasses -@pytest.mark.parametrize(('space_adapter_type'), [ - # Enumerate all supported SpaceAdapters - # *[member for member in SpaceAdapterType], - *list(SpaceAdapterType), -]) +@pytest.mark.parametrize( + ("space_adapter_type"), + [ + # Enumerate all supported SpaceAdapters + # *[member for member in SpaceAdapterType], + *list(SpaceAdapterType), + ], +) def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: """ Test that all optimizer types are listed in the ConcreteOptimizer constraints. """ # pylint: disable=no-member - assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] + assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] -@pytest.mark.parametrize(('space_adapter_type', 'kwargs'), [ - # Default space adapter - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in SpaceAdapterType], -]) -def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("space_adapter_type", "kwargs"), + [ + # Default space adapter + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in SpaceAdapterType], + ], +) +def test_create_space_adapter_with_factory_method( + space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict] +) -> None: # Start defining a ConfigurationSpace for the Optimizer to search. input_space = CS.ConfigurationSpace(seed=1234) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name="x", lower=0, upper=1) + ) # Add a single continuous input dimension between 0 and 1. 
- input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=1)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name="y", lower=0, upper=1) + ) # Adjust some kwargs for specific space adapters if space_adapter_type is SpaceAdapterType.LLAMATUNE: if kwargs is None: kwargs = {} - kwargs.setdefault('num_low_dims', 1) + kwargs.setdefault("num_low_dims", 1) space_adapter: BaseSpaceAdapter if space_adapter_type is None: @@ -73,21 +85,25 @@ def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[S assert space_adapter is not None assert space_adapter.orig_parameter_space is not None myrepr = repr(space_adapter) - assert myrepr.startswith(space_adapter_type.value.__name__), \ - f"Expected {space_adapter_type.value.__name__} but got {myrepr}" + assert myrepr.startswith( + space_adapter_type.value.__name__ + ), f"Expected {space_adapter_type.value.__name__} but got {myrepr}" # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = \ - get_all_concrete_subclasses(BaseSpaceAdapter, pkg_name='mlos_core') # type: ignore[type-abstract] +space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = get_all_concrete_subclasses( + BaseSpaceAdapter, pkg_name="mlos_core" +) # type: ignore[type-abstract] assert space_adapter_subclasses -@pytest.mark.parametrize(('space_adapter_class'), space_adapter_subclasses) +@pytest.mark.parametrize(("space_adapter_class"), space_adapter_subclasses) def test_space_adapter_type_defs(space_adapter_class: Type[BaseSpaceAdapter]) -> None: """ Test that all space adapter classes are listed in the SpaceAdapterType enum. """ - space_adapter_type_classes = {space_adapter_type.value for space_adapter_type in SpaceAdapterType} + space_adapter_type_classes = { + space_adapter_type.value for space_adapter_type in SpaceAdapterType + } assert space_adapter_class in space_adapter_type_classes diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index dee9251652..20666df721 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -41,9 +41,9 @@ def assert_is_uniform(arr: npt.NDArray) -> None: assert np.isclose(frequencies.sum(), 1) _f_chi_sq, f_p_value = scipy.stats.chisquare(frequencies) - assert np.isclose(kurtosis, -1.2, atol=.1) - assert p_value > .3 - assert f_p_value > .5 + assert np.isclose(kurtosis, -1.2, atol=0.1) + assert p_value > 0.3 + assert f_p_value > 0.5 def assert_is_log_uniform(arr: npt.NDArray, base: float = np.e) -> None: @@ -70,17 +70,20 @@ def invalid_conversion_function(*args: Any) -> NoReturn: """ A quick dummy function for the base class to make pylint happy. """ - raise NotImplementedError('subclass must override conversion_function') + raise NotImplementedError("subclass must override conversion_function") class BaseConversion(metaclass=ABCMeta): """ Base class for testing optimizer space conversions. """ + conversion_function: Callable[..., OptimizerSpace] = invalid_conversion_function @abstractmethod - def sample(self, config_space: OptimizerSpace, n_samples: int = 1) -> OptimizerParam: + def sample( + self, config_space: OptimizerSpace, n_samples: int = 1 + ) -> OptimizerParam: """ Sample from the given configuration space. 
@@ -128,8 +131,12 @@ def test_unsupported_hyperparameter(self) -> None: def test_continuous_bounds(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformFloatHyperparameter("a", lower=100, upper=200)) - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("b", lower=-10, upper=-5)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter("a", lower=100, upper=200) + ) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter("b", lower=-10, upper=-5) + ) converted_space = self.conversion_function(input_space) assert self.get_parameter_names(converted_space) == ["a", "b"] @@ -139,8 +146,12 @@ def test_continuous_bounds(self) -> None: def test_uniform_samples(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformFloatHyperparameter("a", lower=1, upper=5)) - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("c", lower=1, upper=20)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter("a", lower=1, upper=5) + ) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter("c", lower=1, upper=20) + ) converted_space = self.conversion_function(input_space) np.random.seed(42) @@ -150,14 +161,16 @@ def test_uniform_samples(self) -> None: assert_is_uniform(uniform) # Check that we get both ends of the sampled range returned to us. - assert input_space['c'].lower in integer_uniform - assert input_space['c'].upper in integer_uniform + assert input_space["c"].lower in integer_uniform + assert input_space["c"].upper in integer_uniform # integer uniform assert_is_uniform(integer_uniform) def test_uniform_categorical(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"])) + input_space.add_hyperparameter( + CS.CategoricalHyperparameter("c", choices=["foo", "bar"]) + ) converted_space = self.conversion_function(input_space) points = self.sample(converted_space, n_samples=100) counts = self.categorical_counts(points) @@ -165,13 +178,13 @@ def test_uniform_categorical(self) -> None: assert 35 < counts[1] < 65 def test_weighted_categorical(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") def test_log_int_spaces(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") def test_log_float_spaces(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") class TestFlamlConversion(BaseConversion): @@ -184,10 +197,12 @@ class TestFlamlConversion(BaseConversion): def sample(self, config_space: FlamlSpace, n_samples: int = 1) -> npt.NDArray: # type: ignore[override] assert isinstance(config_space, dict) assert isinstance(next(iter(config_space.values())), flaml.tune.sample.Domain) - ret: npt.NDArray = np.array([domain.sample(size=n_samples) for domain in config_space.values()]).T + ret: npt.NDArray = np.array( + [domain.sample(size=n_samples) for domain in config_space.values()] + ).T return ret - def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] + def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] assert isinstance(config_space, dict) ret: List[str] = list(config_space.keys()) return ret @@ -199,16 +214,26 @@ def categorical_counts(self, points: npt.NDArray) -> npt.NDArray: def 
test_dimensionality(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("a", lower=1, upper=10)) - input_space.add_hyperparameter(CS.CategoricalHyperparameter("b", choices=["bof", "bum"])) - input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"])) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter("a", lower=1, upper=10) + ) + input_space.add_hyperparameter( + CS.CategoricalHyperparameter("b", choices=["bof", "bum"]) + ) + input_space.add_hyperparameter( + CS.CategoricalHyperparameter("c", choices=["foo", "bar"]) + ) output_space = configspace_to_flaml_space(input_space) assert len(output_space) == 3 def test_weighted_categorical(self) -> None: np.random.seed(42) input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1])) + input_space.add_hyperparameter( + CS.CategoricalHyperparameter( + "c", choices=["foo", "bar"], weights=[0.9, 0.1] + ) + ) with pytest.raises(ValueError, match="non-uniform"): configspace_to_flaml_space(input_space) @@ -217,7 +242,9 @@ def test_log_int_spaces(self) -> None: np.random.seed(42) # integer is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True)) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True) + ) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -235,7 +262,9 @@ def test_log_float_spaces(self) -> None: # continuous is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True) + ) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -245,6 +274,6 @@ def test_log_float_spaces(self) -> None: assert_is_log_uniform(float_log_uniform) -if __name__ == '__main__': +if __name__ == "__main__": # For attaching debugger debugging: pytest.main(["-vv", "-k", "test_log_int_spaces", __file__]) diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index df0e144535..e4533558a9 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -28,7 +28,9 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame: return pd.DataFrame([dict(config)]) -def normalize_config(config_space: ConfigurationSpace, config: Union[Configuration, dict]) -> Configuration: +def normalize_config( + config_space: ConfigurationSpace, config: Union[Configuration, dict] +) -> Configuration: """ Convert a dictionary to a valid ConfigSpace configuration. @@ -47,10 +49,13 @@ def normalize_config(config_space: ConfigurationSpace, config: Union[Configurati cs_config: Configuration A valid ConfigSpace configuration with inactive parameters removed. 
""" - cs_config = Configuration(config_space, values=config, allow_inactive_with_values=True) + cs_config = Configuration( + config_space, values=config, allow_inactive_with_values=True + ) return Configuration( - config_space, values={ + config_space, + values={ key: cs_config[key] for key in config_space.get_active_hyperparameters(cs_config) - } + }, ) diff --git a/mlos_core/mlos_core/version.py b/mlos_core/mlos_core/version.py index 2362de7083..f946f94aa4 100644 --- a/mlos_core/mlos_core/version.py +++ b/mlos_core/mlos_core/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_core/setup.py b/mlos_core/setup.py index fed376d1af..4a76b78020 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -21,21 +21,24 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: warning("setuptools_scm not found, using version from version.py") except LookupError as e: - warning(f"setuptools_scm failed to find git version, using version from version.py: {e}") + warning( + f"setuptools_scm failed to find git version, using version from version.py: {e}" + ) # A simple routine to read and adjust the README.md for this module into a format @@ -49,53 +52,59 @@ # we return nothing when the file is not available. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. 
- lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } -extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass - 'flaml': ['flaml[blendsearch]'], - 'smac': ['smac>=2.0.0'], # NOTE: Major refactoring on SMAC starting from v2.0.0 -} +extra_requires: Dict[str, List[str]] = ( + { # pylint: disable=consider-using-namedtuple-or-dataclass + "flaml": ["flaml[blendsearch]"], + "smac": ["smac>=2.0.0"], # NOTE: Major refactoring on SMAC starting from v2.0.0 + } +) # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", ] setup( version=VERSION, install_requires=[ - 'scikit-learn>=1.2', - 'joblib>=1.1.1', # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released - 'scipy>=1.3.2', - 'numpy>=1.24', 'numpy<2.0.0', # FIXME: https://github.com/numpy/numpy/issues/26710 - 'pandas >= 2.2.0;python_version>="3.9"', 'Bottleneck > 1.3.5;python_version>="3.9"', + "scikit-learn>=1.2", + "joblib>=1.1.1", # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released + "scipy>=1.3.2", + "numpy>=1.24", + "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710 + 'pandas >= 2.2.0;python_version>="3.9"', + 'Bottleneck > 1.3.5;python_version>="3.9"', 'pandas >= 1.0.3;python_version<"3.9"', - 'ConfigSpace>=0.7.1', + "ConfigSpace>=0.7.1", ], extras_require=extra_requires, - **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_core"), + **_get_long_desc_from_readme( + "https://github.com/microsoft/MLOS/tree/main/mlos_core" + ), ) diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py index 2390554e1e..b7a88957f3 100644 --- a/mlos_viz/mlos_viz/__init__.py +++ b/mlos_viz/mlos_viz/__init__.py @@ -23,7 +23,7 @@ class MlosVizMethod(Enum): """ DABL = "dabl" - AUTO = DABL # use dabl as the current default + AUTO = DABL # use dabl as the current default def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None: @@ -39,17 +39,21 @@ def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) base.ignore_plotter_warnings() if plotter_method == MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel + mlos_viz.dabl.ignore_plotter_warnings() else: raise NotImplementedError(f"Unhandled method: {plotter_method}") -def plot(exp_data: Optional[ExperimentData] = None, *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - plotter_method: MlosVizMethod = MlosVizMethod.AUTO, - filter_warnings: bool = True, - **kwargs: Any) -> None: +def plot( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, 
Literal["min", "max"]]] = None, + plotter_method: MlosVizMethod = MlosVizMethod.AUTO, + filter_warnings: bool = True, + **kwargs: Any, +) -> None: """ Plots the results of the experiment. @@ -77,10 +81,13 @@ def plot(exp_data: Optional[ExperimentData] = None, *, (results_df, _obj_cols) = expand_results_data_args(exp_data, results_df, objectives) base.plot_optimizer_trends(exp_data, results_df=results_df, objectives=objectives) - base.plot_top_n_configs(exp_data, results_df=results_df, objectives=objectives, **kwargs) + base.plot_top_n_configs( + exp_data, results_df=results_df, objectives=objectives, **kwargs + ) if MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel + mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives) else: raise NotImplementedError(f"Unhandled method: {plotter_method}") diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 15358b0862..572759f816 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -20,7 +20,7 @@ from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_viz.util import expand_results_data_args -_SEABORN_VERS = version('seaborn') +_SEABORN_VERS = version("seaborn") def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: @@ -30,7 +30,7 @@ def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: Note: this only works with non-positional kwargs (e.g., those after a * arg). """ target_kwargs = {} - for kword in target.__kwdefaults__: # or {} # intentionally omitted for now + for kword in target.__kwdefaults__: # or {} # intentionally omitted for now if kword in kwargs: target_kwargs[kword] = kwargs[kword] return target_kwargs @@ -42,14 +42,19 @@ def ignore_plotter_warnings() -> None: adding them to the warnings filter. """ warnings.filterwarnings("ignore", category=FutureWarning) - if _SEABORN_VERS <= '0.13.1': - warnings.filterwarnings("ignore", category=DeprecationWarning, module="seaborn", # but actually comes from pandas - message="is_categorical_dtype is deprecated and will be removed in a future version.") + if _SEABORN_VERS <= "0.13.1": + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + module="seaborn", # but actually comes from pandas + message="is_categorical_dtype is deprecated and will be removed in a future version.", + ) -def _add_groupby_desc_column(results_df: pandas.DataFrame, - groupby_columns: Optional[List[str]] = None, - ) -> Tuple[pandas.DataFrame, List[str], str]: +def _add_groupby_desc_column( + results_df: pandas.DataFrame, + groupby_columns: Optional[List[str]] = None, +) -> Tuple[pandas.DataFrame, List[str], str]: """ Adds a group descriptor column to the results_df. 
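For context on the signature rewrites in the hunks above: when a def or call no longer fits within black's default 88-character limit (the same limit setup.cfg keeps for pycodestyle/flake8 later in this series), black explodes it to one parameter per line and appends a trailing comma so the layout stays stable under future edits. A minimal sketch with a hypothetical function, not code from this repository:

    # Before: a single line over the 88-column limit.
    def plot_results(data=None, *, objectives=None, method="auto", filter_warnings=True, **kwargs):
        pass

    # After black: one parameter per line, a dedented closing parenthesis,
    # and a "magic" trailing comma that keeps the exploded layout.
    def plot_results(
        data=None,
        *,
        objectives=None,
        method="auto",
        filter_warnings=True,
        **kwargs,
    ):
        pass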
@@ -67,17 +72,19 @@ def _add_groupby_desc_column(results_df: pandas.DataFrame, if groupby_columns is None: groupby_columns = ["tunable_config_trial_group_id", "tunable_config_id"] groupby_column = ",".join(groupby_columns) - results_df[groupby_column] = results_df[groupby_columns].astype(str).apply( - lambda x: ",".join(x), axis=1) # pylint: disable=unnecessary-lambda + results_df[groupby_column] = ( + results_df[groupby_columns].astype(str).apply(lambda x: ",".join(x), axis=1) + ) # pylint: disable=unnecessary-lambda groupby_columns.append(groupby_column) return (results_df, groupby_columns, groupby_column) -def augment_results_df_with_config_trial_group_stats(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - requested_result_cols: Optional[Iterable[str]] = None, - ) -> pandas.DataFrame: +def augment_results_df_with_config_trial_group_stats( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + requested_result_cols: Optional[Iterable[str]] = None, +) -> pandas.DataFrame: # pylint: disable=too-complex """ Add a number of useful statistical measure columns to the results dataframe. @@ -134,30 +141,50 @@ def augment_results_df_with_config_trial_group_stats(exp_data: Optional[Experime raise ValueError(f"Not enough data: {len(results_groups)}") if requested_result_cols is None: - result_cols = set(col for col in results_df.columns if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX)) + result_cols = set( + col + for col in results_df.columns + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) + ) else: - result_cols = set(col for col in requested_result_cols - if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns) - result_cols.update(set(ExperimentData.RESULT_COLUMN_PREFIX + col for col in requested_result_cols - if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns)) + result_cols = set( + col + for col in requested_result_cols + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) + and col in results_df.columns + ) + result_cols.update( + set( + ExperimentData.RESULT_COLUMN_PREFIX + col + for col in requested_result_cols + if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns + ) + ) def compute_zscore_for_group_agg( - results_groups_perf: "SeriesGroupBy", - stats_df: pandas.DataFrame, - result_col: str, - agg: Union[Literal["mean"], Literal["var"], Literal["std"]] + results_groups_perf: "SeriesGroupBy", + stats_df: pandas.DataFrame, + result_col: str, + agg: Union[Literal["mean"], Literal["var"], Literal["std"]], ) -> None: - results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? + results_groups_perf_aggs = results_groups_perf.agg( + agg + ) # TODO: avoid recalculating? # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe. 
stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean() stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std() - stats_df[result_col + f".{agg}_zscore"] = \ - (stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"]) \ - / stats_df[result_col + f".{agg}_stddev"] - stats_df.drop(columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True) + stats_df[result_col + f".{agg}_zscore"] = ( + stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"] + ) / stats_df[result_col + f".{agg}_stddev"] + stats_df.drop( + columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], + inplace=True, + ) augmented_results_df = results_df - augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform("count") + augmented_results_df["tunable_config_trial_group_size"] = results_groups[ + "trial_id" + ].transform("count") for result_col in result_cols: if not result_col.startswith(ExperimentData.RESULT_COLUMN_PREFIX): continue @@ -170,26 +197,31 @@ def compute_zscore_for_group_agg( continue results_groups_perf = results_groups[result_col] stats_df = pandas.DataFrame() - stats_df[result_col + ".mean"] = results_groups_perf.transform("mean", numeric_only=True) + stats_df[result_col + ".mean"] = results_groups_perf.transform( + "mean", numeric_only=True + ) stats_df[result_col + ".var"] = results_groups_perf.transform("var") - stats_df[result_col + ".stddev"] = stats_df[result_col + ".var"].apply(lambda x: x**0.5) + stats_df[result_col + ".stddev"] = stats_df[result_col + ".var"].apply( + lambda x: x**0.5 + ) compute_zscore_for_group_agg(results_groups_perf, stats_df, result_col, "var") quantiles = [0.50, 0.75, 0.90, 0.95, 0.99] - for quantile in quantiles: # TODO: can we do this in one pass? + for quantile in quantiles: # TODO: can we do this in one pass? quantile_col = f"{result_col}.p{int(quantile * 100)}" stats_df[quantile_col] = results_groups_perf.transform("quantile", quantile) augmented_results_df = pandas.concat([augmented_results_df, stats_df], axis=1) return augmented_results_df -def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - top_n_configs: int = 10, - method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", - ) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: +def limit_top_n_configs( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + top_n_configs: int = 10, + method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", +) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: # pylint: disable=too-many-locals """ Utility function to process the results and determine the best performing @@ -219,7 +251,9 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, raise ValueError(f"Invalid method: {method}") # Prepare the orderby columns. - (results_df, objs_cols) = expand_results_data_args(exp_data, results_df=results_df, objectives=objectives) + (results_df, objs_cols) = expand_results_data_args( + exp_data, results_df=results_df, objectives=objectives + ) assert isinstance(results_df, pandas.DataFrame) # Augment the results dataframe with some useful stats. 
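The .var_zscore column built above is an ordinary z-score of each config trial group's variance relative to all groups, and the next hunk (limit_top_n_configs) drops groups that sit two or more standard deviations out. A rough standalone sketch of the same idea, with invented numbers rather than the real results_df plumbing:

    import pandas as pd

    # Variance of some objective for ten hypothetical config trial groups.
    group_var = pd.Series([1.0, 1.1, 0.9, 1.2, 1.0, 0.8, 1.1, 0.9, 1.0, 9.0])
    zscore = (group_var - group_var.mean()) / group_var.std()
    keep = zscore.abs() < 2  # the last group lands ~2.8 sigma out and is filtered

Here keep is True for the first nine groups and False for the noisy one, which mirrors the (results_df[f"{obj_col}.var_zscore"].abs() < 2) filter in the hunk below.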
@@ -232,13 +266,19 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, # results_df is not None and is in fact a DataFrame, so we periodically assert # it in this func for now. assert results_df is not None - orderby_cols: Dict[str, bool] = {obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items()} + orderby_cols: Dict[str, bool] = { + obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items() + } config_id_col = "tunable_config_id" - group_id_col = "tunable_config_trial_group_id" # first trial_id per config group + group_id_col = "tunable_config_trial_group_id" # first trial_id per config group trial_id_col = "trial_id" - default_config_id = results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id + default_config_id = ( + results_df[trial_id_col].min() + if exp_data is None + else exp_data.default_tunable_config_id + ) assert default_config_id is not None, "Failed to determine default config id." # Filter out configs whose variance is too large. @@ -250,16 +290,20 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, singletons_mask = results_df["tunable_config_trial_group_size"] == 1 else: singletons_mask = results_df["tunable_config_trial_group_size"] > 1 - results_df = results_df.loc[( - (results_df[f"{obj_col}.var_zscore"].abs() < 2) - | (singletons_mask) - | (results_df[config_id_col] == default_config_id) - )] + results_df = results_df.loc[ + ( + (results_df[f"{obj_col}.var_zscore"].abs() < 2) + | (singletons_mask) + | (results_df[config_id_col] == default_config_id) + ) + ] assert results_df is not None # Also, filter results that are worse than the default. - default_config_results_df = results_df.loc[results_df[config_id_col] == default_config_id] - for (orderby_col, ascending) in orderby_cols.items(): + default_config_results_df = results_df.loc[ + results_df[config_id_col] == default_config_id + ] + for orderby_col, ascending in orderby_cols.items(): default_vals = default_config_results_df[orderby_col].unique() assert len(default_vals) == 1 default_val = default_vals[0] @@ -271,29 +315,38 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, # Now regroup and filter to the top-N configs by their group performance dimensions. assert results_df is not None - group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[orderby_cols.keys()] - top_n_config_ids: List[int] = group_results_df.sort_values( - by=list(orderby_cols.keys()), ascending=list(orderby_cols.values())).head(top_n_configs).index.tolist() + group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[ + orderby_cols.keys() + ] + top_n_config_ids: List[int] = ( + group_results_df.sort_values( + by=list(orderby_cols.keys()), ascending=list(orderby_cols.values()) + ) + .head(top_n_configs) + .index.tolist() + ) # Remove the default config if it's included. We'll add it back later. if default_config_id in top_n_config_ids: top_n_config_ids.remove(default_config_id) # Get just the top-n config results. # Sort by the group ids. - top_n_config_results_df = results_df.loc[( - results_df[config_id_col].isin(top_n_config_ids) - )].sort_values([group_id_col, config_id_col, trial_id_col]) + top_n_config_results_df = results_df.loc[ + (results_df[config_id_col].isin(top_n_config_ids)) + ].sort_values([group_id_col, config_id_col, trial_id_col]) # Place the default config at the top of the list. 
top_n_config_ids.insert(0, default_config_id) - top_n_config_results_df = pandas.concat([default_config_results_df, top_n_config_results_df], axis=0) + top_n_config_results_df = pandas.concat( + [default_config_results_df, top_n_config_results_df], axis=0 + ) return (top_n_config_results_df, top_n_config_ids, orderby_cols) def plot_optimizer_trends( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, ) -> None: """ Plots the optimizer trends for the Experiment. @@ -312,12 +365,16 @@ def plot_optimizer_trends( (results_df, obj_cols) = expand_results_data_args(exp_data, results_df, objectives) (results_df, groupby_columns, groupby_column) = _add_groupby_desc_column(results_df) - for (objective_column, ascending) in obj_cols.items(): + for objective_column, ascending in obj_cols.items(): incumbent_column = objective_column + ".incumbent" # Determine the mean of each config trial group to match the box plots. - group_results_df = results_df.groupby(groupby_columns)[objective_column].mean()\ - .reset_index().sort_values(groupby_columns) + group_results_df = ( + results_df.groupby(groupby_columns)[objective_column] + .mean() + .reset_index() + .sort_values(groupby_columns) + ) # # Note: technically the optimizer (usually) uses the *first* result for a # given config trial group before moving on to a new config (x-axis), so @@ -331,9 +388,13 @@ def plot_optimizer_trends( # Calculate the incumbent (best seen so far) if ascending: - group_results_df[incumbent_column] = group_results_df[objective_column].cummin() + group_results_df[incumbent_column] = group_results_df[ + objective_column + ].cummin() else: - group_results_df[incumbent_column] = group_results_df[objective_column].cummax() + group_results_df[incumbent_column] = group_results_df[ + objective_column + ].cummax() (_fig, axis) = plt.subplots(figsize=(15, 5)) @@ -355,24 +416,29 @@ def plot_optimizer_trends( ax=axis, ) - plt.yscale('log') + plt.yscale("log") plt.ylabel(objective_column.replace(ExperimentData.RESULT_COLUMN_PREFIX, "")) plt.xlabel("Config Trial Group ID, Config ID") plt.xticks(rotation=90, fontsize=8) - plt.title("Optimizer Trends for Experiment: " + exp_data.experiment_id if exp_data is not None else "") + plt.title( + "Optimizer Trends for Experiment: " + exp_data.experiment_id + if exp_data is not None + else "" + ) plt.grid() plt.show() # type: ignore[no-untyped-call] -def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - with_scatter_plot: bool = False, - **kwargs: Any, - ) -> None: +def plot_top_n_configs( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + with_scatter_plot: bool = False, + **kwargs: Any, +) -> None: # pylint: disable=too-many-locals """ Plots the top-N configs along with the default config for the given ExperimentData. 
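The incumbent column in plot_optimizer_trends above is simply a running best-so-far over the per-group means: cummin() when the objective direction is "min" (ascending) and cummax() when it is "max". A toy example with made-up scores:

    import pandas as pd

    # Mean score per config trial group, in the order the optimizer tried them.
    scores = pd.Series([10.0, 7.0, 9.0, 5.0, 6.0])
    incumbent = scores.cummin()  # 10.0, 7.0, 7.0, 5.0, 5.0 -- best seen so far when minimizing
    # A maximization objective would use scores.cummax() for the same trend line.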
@@ -400,12 +466,16 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, top_n_config_args["results_df"] = results_df if "objectives" not in top_n_config_args: top_n_config_args["objectives"] = objectives - (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs(exp_data=exp_data, **top_n_config_args) + (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs( + exp_data=exp_data, **top_n_config_args + ) - (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column(top_n_config_results_df) + (top_n_config_results_df, _groupby_columns, groupby_column) = ( + _add_groupby_desc_column(top_n_config_results_df) + ) top_n = len(top_n_config_results_df[groupby_column].unique()) - 1 - for (orderby_col, ascending) in orderby_cols.items(): + for orderby_col, ascending in orderby_cols.items(): opt_tgt = orderby_col.replace(ExperimentData.RESULT_COLUMN_PREFIX, "") (_fig, axis) = plt.subplots() sns.violinplot( @@ -425,12 +495,12 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, plt.grid() (xticks, xlabels) = plt.xticks() # default should be in the first position based on top_n_configs() return - xlabels[0] = "default" # type: ignore[call-overload] - plt.xticks(xticks, xlabels) # type: ignore[arg-type] + xlabels[0] = "default" # type: ignore[call-overload] + plt.xticks(xticks, xlabels) # type: ignore[arg-type] plt.xlabel("Config Trial Group, Config ID") plt.xticks(rotation=90) plt.ylabel(opt_tgt) - plt.yscale('log') + plt.yscale("log") extra_title = "(lower is better)" if ascending else "(lower is better)" plt.title(f"Top {top_n} configs {opt_tgt} {extra_title}") plt.show() # type: ignore[no-untyped-call] diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 504486a58c..9d7f673612 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -15,10 +15,12 @@ from mlos_viz.util import expand_results_data_args -def plot(exp_data: Optional[ExperimentData] = None, *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - ) -> None: +def plot( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, +) -> None: """ Plots the Experiment results data using dabl. 
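The dabl.py hunk that follows reflows a block of warnings.filterwarnings(...) calls. One detail worth keeping in mind when reading them: the message and module arguments are regular expressions matched against the start of the warning message and of the warning's module, which is why a pattern such as "(Dropped|Discarding) .* outliers" works as written. A small self-contained illustration, not project code:

    import warnings

    warnings.filterwarnings(
        "ignore",
        category=UserWarning,
        message="(Dropped|Discarding) .* outliers",  # regex, anchored at the start of the message
    )
    warnings.warn("Dropped 3 outliers from the plot", UserWarning)  # suppressed by the filter above
    warnings.warn("some outliers were dropped", UserWarning)        # not suppressed: no match at the start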
@@ -44,17 +46,51 @@ def ignore_plotter_warnings() -> None: """ # pylint: disable=import-outside-toplevel warnings.filterwarnings("ignore", category=FutureWarning) - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Could not infer format") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, - message="Missing values in target_col have been removed for regression") + warnings.filterwarnings( + "ignore", module="dabl", category=UserWarning, message="Could not infer format" + ) + warnings.filterwarnings( + "ignore", + module="dabl", + category=UserWarning, + message="(Dropped|Discarding) .* outliers", + ) + warnings.filterwarnings( + "ignore", + module="dabl", + category=UserWarning, + message="Not plotting highly correlated", + ) + warnings.filterwarnings( + "ignore", + module="dabl", + category=UserWarning, + message="Missing values in target_col have been removed for regression", + ) from sklearn.exceptions import UndefinedMetricWarning - warnings.filterwarnings("ignore", module="sklearn", category=UndefinedMetricWarning, message="Recall is ill-defined") - warnings.filterwarnings("ignore", category=DeprecationWarning, - message="is_categorical_dtype is deprecated and will be removed in a future version.") - warnings.filterwarnings("ignore", category=DeprecationWarning, module="sklearn", - message="is_sparse is deprecated and will be removed in a future version.") + + warnings.filterwarnings( + "ignore", + module="sklearn", + category=UndefinedMetricWarning, + message="Recall is ill-defined", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="is_categorical_dtype is deprecated and will be removed in a future version.", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + module="sklearn", + message="is_sparse is deprecated and will be removed in a future version.", + ) from matplotlib._api.deprecation import MatplotlibDeprecationWarning - warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning, module="dabl", - message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed") + + warnings.filterwarnings( + "ignore", + category=MatplotlibDeprecationWarning, + module="dabl", + message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed", + ) diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py index 06ac4a7664..e5528f9875 100644 --- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py +++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py @@ -30,5 +30,5 @@ def test_plot(mock_show: Mock, mock_boxplot: Mock, exp_data: ExperimentData) -> warnings.simplefilter("error") random.seed(42) plot(exp_data, filter_warnings=True) - assert mock_show.call_count >= 2 # from the two base plots and anything dabl did - assert mock_boxplot.call_count >= 1 # from anything dabl did + assert mock_show.call_count >= 2 # from the two base plots and anything dabl did + assert mock_boxplot.call_count >= 1 # from anything dabl did diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index 744fe28648..4e081193bc 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -41,24 +41,35 @@ def expand_results_data_args( # Prepare the orderby columns. 
if results_df is None: if exp_data is None: - raise ValueError("Must provide either exp_data or both results_df and objectives.") + raise ValueError( + "Must provide either exp_data or both results_df and objectives." + ) results_df = exp_data.results_df if objectives is None: if exp_data is None: - raise ValueError("Must provide either exp_data or both results_df and objectives.") + raise ValueError( + "Must provide either exp_data or both results_df and objectives." + ) objectives = exp_data.objectives objs_cols: Dict[str, bool] = {} - for (opt_tgt, opt_dir) in objectives.items(): + for opt_tgt, opt_dir in objectives.items(): if opt_dir not in ["min", "max"]: - raise ValueError(f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}") + raise ValueError( + f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}" + ) ascending = opt_dir == "min" - if opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and opt_tgt in results_df.columns: + if ( + opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) + and opt_tgt in results_df.columns + ): objs_cols[opt_tgt] = ascending elif ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt in results_df.columns: objs_cols[ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt] = ascending else: - raise UserWarning(f"{opt_tgt} is not a result column for experiment {exp_data}") + raise UserWarning( + f"{opt_tgt} is not a result column for experiment {exp_data}" + ) # Note: these copies are important to avoid issues with downstream consumers. # It is more efficient to copy the dataframe than to go back to the original data source. # TODO: However, it should be possible to later fixup the downstream consumers diff --git a/mlos_viz/mlos_viz/version.py b/mlos_viz/mlos_viz/version.py index 607c7cc014..d418ae43c7 100644 --- a/mlos_viz/mlos_viz/version.py +++ b/mlos_viz/mlos_viz/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index 98d12598e1..d8f6595813 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -21,21 +21,24 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: warning("setuptools_scm not found, using version from version.py") except LookupError as e: - warning(f"setuptools_scm failed to find git version, using version from version.py: {e}") + warning( + f"setuptools_scm failed to find git version, using version from version.py: {e}" + ) # A simple routine to read and adjust the README.md for this module into a format @@ -47,22 +50,22 @@ # be duplicated for now. 
def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } @@ -70,23 +73,25 @@ def _get_long_desc_from_readme(base_url: str) -> dict: # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", ] setup( version=VERSION, install_requires=[ - 'mlos-bench==' + VERSION, - 'dabl>=0.2.6', - 'matplotlib<3.9', # FIXME: https://github.com/dabl/dabl/pull/341 + "mlos-bench==" + VERSION, + "dabl>=0.2.6", + "matplotlib<3.9", # FIXME: https://github.com/dabl/dabl/pull/341 ], extras_require=extra_requires, - **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_viz'), + **_get_long_desc_from_readme( + "https://github.com/microsoft/MLOS/tree/main/mlos_viz" + ), ) From ae3bd626d11cb346021257a187af57cf39deb69f Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 3 Jul 2024 21:52:33 +0000 Subject: [PATCH 07/54] incompatible rules between flake and black and pycodestyle --- .../tests/services/remote/azure/azure_fileshare_test.py | 2 +- setup.cfg | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index 2858b2388c..64633a534b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -215,7 +215,7 @@ def process_paths(input_path: str) -> str: skip_prefix = os.getcwd() # Remove prefix from os.path.abspath if there if input_path == os.path.abspath(input_path): - result = input_path[len(skip_prefix) + 1 :] + result = input_path[(len(skip_prefix) + 1) :] else: result = input_path # Change file seps to unix-style diff --git a/setup.cfg b/setup.cfg index 661a7971cd..b1cf391742 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,9 +2,10 @@ [pycodestyle] count = True +# E203: 
Whitespace before : (black incompatibility) # W503: Line break occurred before a binary operator # W504: Line break occurred after a binary operator -ignore = W503,W504 +ignore = E203,W503,W504 format = pylint # See Also: .editorconfig, .pylintrc max-line-length = 88 @@ -26,6 +27,8 @@ convention = numpy [flake8] max-line-length = 88 +# black incompatibility +extend-ignore = E203 [tool:pytest] minversion = 7.1 From 8af6eaced40512f99e622c86c6eec389b7aab2ab Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 3 Jul 2024 21:54:56 +0000 Subject: [PATCH 08/54] enable string quote consistency checks --- .pylintrc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pylintrc b/.pylintrc index fdc93e2956..6b873c5d60 100644 --- a/.pylintrc +++ b/.pylintrc @@ -48,5 +48,5 @@ disable= missing-raises-doc [STRING] -#check-quote-consistency=yes +check-quote-consistency=yes check-str-concat-over-line-jumps=yes From a6d3214f6bee0c61e909ec44a5bea0635ef4207a Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 16:14:51 +0000 Subject: [PATCH 09/54] Revert "black formatting" This reverts commit 1b9843a587b77daab06df362a0b109b2712b91a2. --- .../fio/scripts/local/process_fio_results.py | 26 +- .../scripts/local/generate_redis_config.py | 12 +- .../scripts/local/process_redis_results.py | 30 +- .../boot/scripts/local/create_new_grub_cfg.py | 11 +- .../scripts/local/generate_grub_config.py | 10 +- .../local/generate_kernel_config_script.py | 5 +- .../mlos_bench/config/schemas/__init__.py | 4 +- .../config/schemas/config_schemas.py | 21 +- mlos_bench/mlos_bench/dict_templater.py | 21 +- .../mlos_bench/environments/__init__.py | 15 +- .../environments/base_environment.py | 118 ++-- .../mlos_bench/environments/composite_env.py | 73 +-- .../mlos_bench/environments/local/__init__.py | 4 +- .../environments/local/local_env.py | 139 ++--- .../environments/local/local_fileshare_env.py | 68 +-- .../mlos_bench/environments/mock_env.py | 52 +- .../environments/remote/__init__.py | 12 +- .../environments/remote/host_env.py | 37 +- .../environments/remote/network_env.py | 45 +- .../mlos_bench/environments/remote/os_env.py | 40 +- .../environments/remote/remote_env.py | 46 +- .../environments/remote/saas_env.py | 46 +- .../mlos_bench/environments/script_env.py | 45 +- mlos_bench/mlos_bench/event_loop_context.py | 14 +- mlos_bench/mlos_bench/launcher.py | 288 ++++------ mlos_bench/mlos_bench/optimizers/__init__.py | 8 +- .../mlos_bench/optimizers/base_optimizer.py | 97 ++-- .../optimizers/convert_configspace.py | 122 ++--- .../optimizers/grid_search_optimizer.py | 91 +--- .../optimizers/mlos_core_optimizer.py | 100 ++-- .../mlos_bench/optimizers/mock_optimizer.py | 26 +- .../optimizers/one_shot_optimizer.py | 12 +- .../optimizers/track_best_optimizer.py | 26 +- mlos_bench/mlos_bench/os_environ.py | 11 +- mlos_bench/mlos_bench/run.py | 9 +- mlos_bench/mlos_bench/schedulers/__init__.py | 4 +- .../mlos_bench/schedulers/base_scheduler.py | 113 ++-- .../mlos_bench/schedulers/sync_scheduler.py | 4 +- mlos_bench/mlos_bench/services/__init__.py | 6 +- .../mlos_bench/services/base_fileshare.py | 43 +- .../mlos_bench/services/base_service.py | 68 +-- .../mlos_bench/services/config_persistence.py | 309 ++++------- .../mlos_bench/services/local/__init__.py | 2 +- .../mlos_bench/services/local/local_exec.py | 59 +- .../services/local/temp_dir_context.py | 26 +- .../services/remote/azure/__init__.py | 10 +- .../services/remote/azure/azure_auth.py | 52 +- .../remote/azure/azure_deployment_services.py | 175 ++---- 
.../services/remote/azure/azure_fileshare.py | 35 +- .../remote/azure/azure_network_services.py | 81 ++- .../services/remote/azure/azure_saas.py | 133 ++--- .../remote/azure/azure_vm_services.py | 269 ++++------ .../services/remote/ssh/ssh_fileshare.py | 65 +-- .../services/remote/ssh/ssh_host_service.py | 109 ++-- .../services/remote/ssh/ssh_service.py | 185 ++----- .../mlos_bench/services/types/__init__.py | 16 +- .../services/types/config_loader_type.py | 45 +- .../services/types/fileshare_type.py | 8 +- .../services/types/host_provisioner_type.py | 4 +- .../services/types/local_exec_type.py | 13 +- .../types/network_provisioner_type.py | 8 +- .../services/types/remote_config_type.py | 5 +- .../services/types/remote_exec_type.py | 5 +- mlos_bench/mlos_bench/storage/__init__.py | 4 +- .../storage/base_experiment_data.py | 19 +- mlos_bench/mlos_bench/storage/base_storage.py | 131 ++--- .../mlos_bench/storage/base_trial_data.py | 26 +- .../storage/base_tunable_config_data.py | 3 +- .../base_tunable_config_trial_group_data.py | 24 +- mlos_bench/mlos_bench/storage/sql/__init__.py | 2 +- mlos_bench/mlos_bench/storage/sql/common.py | 233 +++----- .../mlos_bench/storage/sql/experiment.py | 281 ++++------ .../mlos_bench/storage/sql/experiment_data.py | 105 ++-- mlos_bench/mlos_bench/storage/sql/schema.py | 52 +- mlos_bench/mlos_bench/storage/sql/storage.py | 29 +- mlos_bench/mlos_bench/storage/sql/trial.py | 146 ++--- .../mlos_bench/storage/sql/trial_data.py | 80 +-- .../storage/sql/tunable_config_data.py | 14 +- .../sql/tunable_config_trial_group_data.py | 43 +- .../mlos_bench/storage/storage_factory.py | 8 +- mlos_bench/mlos_bench/storage/util.py | 24 +- mlos_bench/mlos_bench/tests/__init__.py | 40 +- .../mlos_bench/tests/config/__init__.py | 12 +- .../cli/test_load_cli_config_examples.py | 89 +-- .../mlos_bench/tests/config/conftest.py | 14 +- .../test_load_environment_config_examples.py | 72 +-- .../test_load_global_config_examples.py | 8 +- .../test_load_optimizer_config_examples.py | 8 +- .../tests/config/schemas/__init__.py | 76 +-- .../config/schemas/cli/test_cli_schemas.py | 13 +- .../environments/test_environment_schemas.py | 42 +- .../schemas/globals/test_globals_schemas.py | 9 +- .../optimizers/test_optimizer_schemas.py | 89 +-- .../schedulers/test_scheduler_schemas.py | 42 +- .../schemas/services/test_services_schemas.py | 43 +- .../schemas/storage/test_storage_schemas.py | 52 +- .../test_tunable_params_schemas.py | 9 +- .../test_tunable_values_schemas.py | 9 +- .../test_load_service_config_examples.py | 14 +- .../test_load_storage_config_examples.py | 8 +- mlos_bench/mlos_bench/tests/conftest.py | 24 +- .../mlos_bench/tests/dict_templater_test.py | 4 +- .../mlos_bench/tests/environments/__init__.py | 14 +- .../tests/environments/base_env_test.py | 4 +- .../composite_env_service_test.py | 36 +- .../tests/environments/composite_env_test.py | 160 +++--- .../environments/include_tunables_test.py | 36 +- .../tests/environments/local/__init__.py | 20 +- .../local/composite_local_env_test.py | 23 +- .../local/local_env_stdout_test.py | 88 ++- .../local/local_env_telemetry_test.py | 149 +++--- .../environments/local/local_env_test.py | 73 ++- .../environments/local/local_env_vars_test.py | 61 +-- .../local/local_fileshare_env_test.py | 25 +- .../tests/environments/mock_env_test.py | 87 ++- .../tests/environments/remote/test_ssh_env.py | 18 +- .../tests/event_loop_context_test.py | 69 +-- .../tests/launcher_in_process_test.py | 40 +- .../tests/launcher_parse_args_test.py | 147 +++-- 
.../mlos_bench/tests/launcher_run_test.py | 107 ++-- .../mlos_bench/tests/optimizers/conftest.py | 40 +- .../optimizers/grid_search_optimizer_test.py | 144 ++--- .../tests/optimizers/llamatune_opt_test.py | 9 +- .../tests/optimizers/mlos_core_opt_df_test.py | 68 +-- .../optimizers/mlos_core_opt_smac_test.py | 96 ++-- .../tests/optimizers/mock_opt_test.py | 71 +-- .../optimizers/opt_bulk_register_test.py | 101 ++-- .../optimizers/toy_optimization_loop_test.py | 30 +- .../mlos_bench/tests/services/__init__.py | 8 +- .../tests/services/config_persistence_test.py | 50 +- .../tests/services/local/__init__.py | 2 +- .../services/local/local_exec_python_test.py | 15 +- .../tests/services/local/local_exec_test.py | 126 ++--- .../tests/services/local/mock/__init__.py | 2 +- .../local/mock/mock_local_exec_service.py | 26 +- .../mlos_bench/tests/services/mock_service.py | 23 +- .../tests/services/remote/__init__.py | 6 +- .../remote/azure/azure_fileshare_test.py | 164 ++---- .../azure/azure_network_services_test.py | 99 ++-- .../remote/azure/azure_vm_services_test.py | 231 +++----- .../tests/services/remote/azure/conftest.py | 108 ++-- .../services/remote/mock/mock_auth_service.py | 26 +- .../remote/mock/mock_fileshare_service.py | 25 +- .../remote/mock/mock_network_service.py | 35 +- .../remote/mock/mock_remote_exec_service.py | 26 +- .../services/remote/mock/mock_vm_service.py | 55 +- .../tests/services/remote/ssh/__init__.py | 18 +- .../tests/services/remote/ssh/fixtures.py | 67 +-- .../services/remote/ssh/test_ssh_fileshare.py | 48 +- .../remote/ssh/test_ssh_host_service.py | 102 ++-- .../services/remote/ssh/test_ssh_service.py | 65 +-- .../test_service_method_registering.py | 10 +- .../mlos_bench/tests/storage/conftest.py | 8 +- .../mlos_bench/tests/storage/exp_data_test.py | 82 +-- .../mlos_bench/tests/storage/exp_load_test.py | 80 ++- .../mlos_bench/tests/storage/sql/fixtures.py | 97 ++-- .../tests/storage/trial_config_test.py | 19 +- .../tests/storage/trial_schedule_test.py | 36 +- .../tests/storage/trial_telemetry_test.py | 49 +- .../tests/storage/tunable_config_data_test.py | 26 +- .../tunable_config_trial_group_data_test.py | 68 +-- .../mlos_bench/tests/test_with_alt_tz.py | 8 +- .../tests/tunable_groups_fixtures.py | 38 +- .../mlos_bench/tests/tunables/conftest.py | 47 +- .../tunables/test_tunable_categoricals.py | 2 +- .../tunables/test_tunables_size_props.py | 23 +- .../tests/tunables/tunable_comparison_test.py | 15 +- .../tests/tunables/tunable_definition_test.py | 108 ++-- .../tunables/tunable_distributions_test.py | 68 +-- .../tunables/tunable_group_indexing_test.py | 12 +- .../tunables/tunable_group_subgroup_test.py | 2 +- .../tunable_to_configspace_distr_test.py | 54 +- .../tunables/tunable_to_configspace_test.py | 59 +- .../tests/tunables/tunables_assign_test.py | 26 +- .../tests/tunables/tunables_str_test.py | 76 ++- mlos_bench/mlos_bench/tunables/__init__.py | 6 +- .../mlos_bench/tunables/covariant_group.py | 18 +- mlos_bench/mlos_bench/tunables/tunable.py | 124 ++--- .../mlos_bench/tunables/tunable_groups.py | 76 +-- mlos_bench/mlos_bench/util.py | 59 +- mlos_bench/mlos_bench/version.py | 2 +- mlos_bench/setup.py | 95 ++-- mlos_core/mlos_core/optimizers/__init__.py | 32 +- .../bayesian_optimizers/__init__.py | 4 +- .../bayesian_optimizers/bayesian_optimizer.py | 14 +- .../bayesian_optimizers/smac_optimizer.py | 179 ++----- .../mlos_core/optimizers/flaml_optimizer.py | 78 +-- mlos_core/mlos_core/optimizers/optimizer.py | 176 ++---- .../mlos_core/optimizers/random_optimizer.py | 
47 +- .../mlos_core/spaces/adapters/__init__.py | 19 +- .../mlos_core/spaces/adapters/adapter.py | 10 +- .../mlos_core/spaces/adapters/llamatune.py | 209 +++----- .../mlos_core/spaces/converters/flaml.py | 26 +- mlos_core/mlos_core/tests/__init__.py | 23 +- .../optimizers/bayesian_optimizers_test.py | 23 +- .../mlos_core/tests/optimizers/conftest.py | 10 +- .../tests/optimizers/one_hot_test.py | 96 ++-- .../optimizers/optimizer_multiobj_test.py | 80 ++- .../tests/optimizers/optimizer_test.py | 241 ++++----- .../spaces/adapters/identity_adapter_test.py | 25 +- .../tests/spaces/adapters/llamatune_test.py | 505 ++++++------------ .../adapters/space_adapter_factory_test.py | 60 +-- .../mlos_core/tests/spaces/spaces_test.py | 77 +-- mlos_core/mlos_core/util.py | 13 +- mlos_core/mlos_core/version.py | 2 +- mlos_core/setup.py | 71 ++- mlos_viz/mlos_viz/__init__.py | 23 +- mlos_viz/mlos_viz/base.py | 234 +++----- mlos_viz/mlos_viz/dabl.py | 68 +-- mlos_viz/mlos_viz/tests/test_mlos_viz.py | 4 +- mlos_viz/mlos_viz/util.py | 23 +- mlos_viz/mlos_viz/version.py | 2 +- mlos_viz/setup.py | 53 +- 213 files changed, 4388 insertions(+), 7912 deletions(-) diff --git a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py index 75c72e6207..c32dea9bf6 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py @@ -20,7 +20,7 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: Flatten every dict in the hierarchy and rename the keys with the dict path. """ if isinstance(data, dict): - for key, val in data.items(): + for (key, val) in data.items(): yield from _flat_dict(val, f"{path}.{key}") else: yield (path, data) @@ -30,15 +30,13 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: """ Convert FIO read data from JSON to tall CSV. 
""" - with open(input_file, mode="r", encoding="utf-8") as fh_input: + with open(input_file, mode='r', encoding='utf-8') as fh_input: json_data = json.load(fh_input) - data = list( - itertools.chain( - _flat_dict(json_data["jobs"][0], prefix), - _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"), - ) - ) + data = list(itertools.chain( + _flat_dict(json_data["jobs"][0], prefix), + _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util") + )) tall_df = pandas.DataFrame(data, columns=["metric", "value"]) tall_df.to_csv(output_file, index=False) @@ -51,16 +49,12 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.") parser.add_argument( - "input", - help="FIO benchmark results in JSON format (downloaded from a remote VM).", - ) + "input", help="FIO benchmark results in JSON format (downloaded from a remote VM).") parser.add_argument( - "output", - help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).", - ) + "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).") parser.add_argument( - "--prefix", default="fio", help="Prefix of the metric IDs (default 'fio')" - ) + "--prefix", default="fio", + help="Prefix of the metric IDs (default 'fio')") args = parser.parse_args() _main(args.input, args.output, args.prefix) diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py index d41f20d2a9..949b9f9d91 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py @@ -14,19 +14,17 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( - fname_output, "wt", encoding="utf-8", newline="" - ) as fh_config: - for key, val in json.load(fh_tunables).items(): - line = f"{key} {val}" + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ + open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: + for (key, val) in json.load(fh_tunables).items(): + line = f'{key} {val}' fh_config.write(line + "\n") print(line) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate Redis config from tunable parameters JSON." - ) + description="generate Redis config from tunable parameters JSON.") parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output Redis config file.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py index eb0b904c5d..e33c717953 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py @@ -21,21 +21,18 @@ def _main(input_file: str, output_file: str) -> None: # Format the results from wide to long # The target is columns of metric and value to act as key-value pairs. 
df_long = ( - df_wide.melt(id_vars=["test"]) + df_wide + .melt(id_vars=["test"]) .assign(metric=lambda df: df["test"] + "_" + df["variable"]) .drop(columns=["test", "variable"]) .loc[:, ["metric", "value"]] ) # Add a default `score` metric to the end of the dataframe. - df_long = pd.concat( - [ - df_long, - pd.DataFrame( - {"metric": ["score"], "value": [df_long.value[df_long.index.max()]]} - ), - ] - ) + df_long = pd.concat([ + df_long, + pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}) + ]) df_long.to_csv(output_file, index=False) print(f"Converted: {input_file} -> {output_file}") @@ -43,16 +40,9 @@ def _main(input_file: str, output_file: str) -> None: if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Post-process Redis benchmark results." - ) - parser.add_argument( - "input", help="Redis benchmark results (downloaded from a remote VM)." - ) - parser.add_argument( - "output", - help="Converted Redis benchmark data" - + " (to be consumed by OS Autotune framework).", - ) + parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.") + parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).") + parser.add_argument("output", help="Converted Redis benchmark data" + + " (to be consumed by OS Autotune framework).") args = parser.parse_args() _main(args.input, args.output) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py index 649d537558..41bd162459 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py @@ -14,11 +14,8 @@ JSON_CONFIG_FILE = "config-boot-time.json" NEW_CFG = "zz-mlos-boot-params.cfg" -with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open( - NEW_CFG, "w", encoding="UTF-8" -) as fh_config: +with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \ + open(NEW_CFG, 'w', encoding='UTF-8') as fh_config: for key, val in json.load(fh_json).items(): - fh_config.write( - 'GRUB_CMDLINE_LINUX_DEFAULT="$' - f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n' - ) + fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$' + f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n') diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py index 9f130e5c0e..de344d61fb 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py @@ -14,10 +14,9 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( - fname_output, "wt", encoding="utf-8", newline="" - ) as fh_config: - for key, val in json.load(fh_tunables).items(): + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ + open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: + for (key, val) in json.load(fh_tunables).items(): line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"' fh_config.write(line + "\n") print(line) @@ -25,8 +24,7 @@ def _main(fname_input: str, fname_output: str) -> None: if __name__ == "__main__": 
parser = argparse.ArgumentParser( - description="Generate GRUB config from tunable parameters JSON." - ) + description="Generate GRUB config from tunable parameters JSON.") parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output shell script to configure GRUB.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py index e632495061..85a49a1817 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py @@ -22,7 +22,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: tunables_meta = json.load(fh_meta) with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for key, val in tunables_data.items(): + for (key, val) in tunables_data.items(): meta = tunables_meta.get(key, {}) name_prefix = meta.get("name_prefix", "") line = f'echo "{val}" > {name_prefix}{key}' @@ -33,8 +33,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate a script to update kernel parameters from tunables JSON." - ) + description="generate a script to update kernel parameters from tunables JSON.") parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("meta", help="JSON file with tunable parameters metadata.") diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py index 672a215aad..fa3b63e2e6 100644 --- a/mlos_bench/mlos_bench/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/config/schemas/__init__.py @@ -9,6 +9,6 @@ from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema __all__ = [ - "ConfigSchema", - "CONFIG_SCHEMA_DIR", + 'ConfigSchema', + 'CONFIG_SCHEMA_DIR', ] diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index 181f96e5d6..82cbcacce2 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -27,14 +27,9 @@ # It is used in `ConfigSchema.validate()` method below. # NOTE: this may cause pytest to fail if it's expecting exceptions # to be raised for invalid configs. 
-_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION" -_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in { - "true", - "y", - "yes", - "on", - "1", -} +_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION' +_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower() + in {'true', 'y', 'yes', 'on', '1'}) # Note: we separate out the SchemaStore from a class method on ConfigSchema @@ -85,12 +80,10 @@ def _load_registry(cls) -> None: """Also store them in a Registry object for referencing by recent versions of jsonschema.""" if not cls._SCHEMA_STORE: cls._load_schemas() - cls._REGISTRY = Registry().with_resources( - [ - (url, Resource.from_contents(schema, default_specification=DRAFT202012)) - for url, schema in cls._SCHEMA_STORE.items() - ] - ) + cls._REGISTRY = Registry().with_resources([ + (url, Resource.from_contents(schema, default_specification=DRAFT202012)) + for url, schema in cls._SCHEMA_STORE.items() + ]) @property def registry(self) -> Registry: diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index 26c573b4c0..4ccef7817b 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -13,7 +13,7 @@ from mlos_bench.os_environ import environ -class DictTemplater: # pylint: disable=too-few-public-methods +class DictTemplater: # pylint: disable=too-few-public-methods """ Simple class to help with nested dictionary $var templating. """ @@ -32,12 +32,9 @@ def __init__(self, source_dict: Dict[str, Any]): # The source/target dictionary to expand. self._dict: Dict[str, Any] = {} - def expand_vars( - self, - *, - extra_source_dict: Optional[Dict[str, Any]] = None, - use_os_env: bool = False, - ) -> Dict[str, Any]: + def expand_vars(self, *, + extra_source_dict: Optional[Dict[str, Any]] = None, + use_os_env: bool = False) -> Dict[str, Any]: """ Expand the template variables in the destination dictionary. @@ -58,9 +55,7 @@ def expand_vars( assert isinstance(self._dict, dict) return self._dict - def _expand_vars( - self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool - ) -> Any: + def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any: """ Recursively expand $var strings in the currently operating dictionary. """ @@ -76,12 +71,10 @@ def _expand_vars( elif isinstance(value, dict): # Note: we use a loop instead of dict comprehension in order to # allow secondary expansion of subsequent values immediately. 
- for key, val in value.items(): + for (key, val) in value.items(): value[key] = self._expand_vars(val, extra_source_dict, use_os_env) elif isinstance(value, list): - value = [ - self._expand_vars(val, extra_source_dict, use_os_env) for val in value - ] + value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value] elif isinstance(value, (int, float, bool)) or value is None: return value else: diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py index 629e7d9c5f..a1ccadae5f 100644 --- a/mlos_bench/mlos_bench/environments/__init__.py +++ b/mlos_bench/mlos_bench/environments/__init__.py @@ -15,11 +15,12 @@ from mlos_bench.environments.status import Status __all__ = [ - "Status", - "Environment", - "MockEnv", - "RemoteEnv", - "LocalEnv", - "LocalFileShareEnv", - "CompositeEnv", + 'Status', + + 'Environment', + 'MockEnv', + 'RemoteEnv', + 'LocalEnv', + 'LocalFileShareEnv', + 'CompositeEnv', ] diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index d358f903be..61fbd69f50 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -48,16 +48,15 @@ class Environment(metaclass=abc.ABCMeta): """ @classmethod - def new( - cls, - *, - env_name: str, - class_name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ) -> "Environment": + def new(cls, + *, + env_name: str, + class_name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -95,18 +94,16 @@ def new( config=config, global_config=global_config, tunables=tunables, - service=service, + service=service ) - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment with a given config. @@ -137,41 +134,34 @@ def __init__( self._const_args: Dict[str, TunableValue] = config.get("const_args", {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Environment: '%s' Service: %s", - name, - self._service.pprint() if self._service else None, - ) + _LOG.debug("Environment: '%s' Service: %s", name, + self._service.pprint() if self._service else None) if tunables is None: - _LOG.warning( - "No tunables provided for %s. Tunable inheritance across composite environments may be broken.", - name, - ) + _LOG.warning("No tunables provided for %s. 
Tunable inheritance across composite environments may be broken.", name) tunables = TunableGroups() groups = self._expand_groups( config.get("tunable_params", []), - (global_config or {}).get("tunable_params_map", {}), - ) + (global_config or {}).get("tunable_params_map", {})) _LOG.debug("Tunable groups for: '%s' :: %s", name, groups) self._tunable_params = tunables.subgroup(groups) # If a parameter comes from the tunables, do not require it in the const_args or globals - req_args = set(config.get("required_args", [])) - set( - self._tunable_params.get_param_values().keys() - ) - merge_parameters( - dest=self._const_args, source=global_config, required_keys=req_args + req_args = ( + set(config.get("required_args", [])) - + set(self._tunable_params.get_param_values().keys()) ) + merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args) self._const_args = self._expand_vars(self._const_args, global_config or {}) self._params = self._combine_tunables(self._tunable_params) _LOG.debug("Parameters for '%s' :: %s", name, self._params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2)) + _LOG.debug("Config for: '%s'\n%s", + name, json.dumps(self.config, indent=2)) def _validate_json_config(self, config: dict, name: str) -> None: """ @@ -189,9 +179,8 @@ def _validate_json_config(self, config: dict, name: str) -> None: ConfigSchema.ENVIRONMENT.validate(json_config) @staticmethod - def _expand_groups( - groups: Iterable[str], groups_exp: Dict[str, Union[str, Sequence[str]]] - ) -> List[str]: + def _expand_groups(groups: Iterable[str], + groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]: """ Expand `$tunable_group` into actual names of the tunable groups. @@ -213,9 +202,7 @@ def _expand_groups( if grp[:1] == "$": tunable_group_name = grp[1:] if tunable_group_name not in groups_exp: - raise KeyError( - f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}" - ) + raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}") add_groups = groups_exp[tunable_group_name] res += [add_groups] if isinstance(add_groups, str) else add_groups else: @@ -223,9 +210,7 @@ def _expand_groups( return res @staticmethod - def _expand_vars( - params: Dict[str, TunableValue], global_config: Dict[str, TunableValue] - ) -> dict: + def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict: """ Expand `$var` into actual values of the variables. """ @@ -236,7 +221,7 @@ def _config_loader_service(self) -> "SupportsConfigLoading": assert self._service is not None return self._service.config_loader_service - def __enter__(self) -> "Environment": + def __enter__(self) -> 'Environment': """ Enter the environment's benchmarking context. """ @@ -247,12 +232,9 @@ def __enter__(self) -> "Environment": self._in_context = True return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the benchmarking environment. 
""" @@ -261,20 +243,14 @@ def __exit__( _LOG.debug("Environment END :: %s", self) else: assert ex_type and ex_val - _LOG.warning( - "Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb) - ) + _LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb)) assert self._in_context if self._service_context: try: self._service_context.__exit__(ex_type, ex_val, ex_tb) # pylint: disable=broad-exception-caught except Exception as ex: - _LOG.error( - "Exception while exiting Service context '%s': %s", - self._service, - ex, - ) + _LOG.error("Exception while exiting Service context '%s': %s", self._service, ex) ex_throw = ex finally: self._service_context = None @@ -328,8 +304,7 @@ def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ return tunables.get_param_values( group_names=list(self._tunable_params.get_covariant_group_names()), - into_params=self._const_args.copy(), - ) + into_params=self._const_args.copy()) @property def tunable_params(self) -> TunableGroups: @@ -356,9 +331,7 @@ def parameters(self) -> Dict[str, TunableValue]: """ return self._params - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Set up a new benchmark environment, if necessary. This method must be idempotent, i.e., calling it several times in a row should be @@ -391,15 +364,10 @@ def setup( # (Derived classes still have to check `self._tunable_params.is_updated()`). is_updated = self._tunable_params.is_updated() if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Env '%s': Tunable groups reset = %s :: %s", - self, - is_updated, - { - name: self._tunable_params.is_updated([name]) - for name in self._tunable_params.get_covariant_group_names() - }, - ) + _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, { + name: self._tunable_params.is_updated([name]) + for name in self._tunable_params.get_covariant_group_names() + }) else: _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated) diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index 4b5e2755cf..a71b8ab9be 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -27,15 +27,13 @@ class CompositeEnv(Environment): Composite benchmark environment. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment with a given config. @@ -55,13 +53,8 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) # By default, the Environment includes only the tunables explicitly specified # in the "tunable_params" section of the config. `CompositeEnv`, however, must @@ -77,28 +70,20 @@ def __init__( # each CompositeEnv gets a copy of the original global config and adjusts it with # the `const_args` specific to it. 
global_config = (global_config or {}).copy() - for key, val in self._const_args.items(): + for (key, val) in self._const_args.items(): global_config.setdefault(key, val) for child_config_file in config.get("include_children", []): for env in self._config_loader_service.load_environment_list( - child_config_file, - tunables, - global_config, - self._const_args, - self._service, - ): + child_config_file, tunables, global_config, self._const_args, self._service): self._add_child(env, tunables) for child_config in config.get("children", []): env = self._config_loader_service.build_environment( - child_config, tunables, global_config, self._const_args, self._service - ) + child_config, tunables, global_config, self._const_args, self._service) self._add_child(env, tunables) - _LOG.debug( - "Build composite environment '%s' END: %s", self, self._tunable_params - ) + _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params) if not self._children: raise ValueError("At least one child environment must be present") @@ -107,21 +92,16 @@ def __enter__(self) -> Environment: self._child_contexts = [env.__enter__() for env in self._children] return super().__enter__() - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: ex_throw = None for env in reversed(self._children): try: env.__exit__(ex_type, ex_val, ex_tb) # pylint: disable=broad-exception-caught except Exception as ex: - _LOG.error( - "Exception while exiting child environment '%s': %s", env, ex - ) + _LOG.error("Exception while exiting child environment '%s': %s", env, ex) ex_throw = ex self._child_contexts = [] super().__exit__(ex_type, ex_val, ex_tb) @@ -152,11 +132,8 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: pretty : str Pretty-printed environment configuration. """ - return ( - super().pprint(indent, level) - + "\n" - + "\n".join(child.pprint(indent, level + 1) for child in self._children) - ) + return super().pprint(indent, level) + '\n' + '\n'.join( + child.pprint(indent, level + 1) for child in self._children) def _add_child(self, env: Environment, tunables: TunableGroups) -> None: """ @@ -168,9 +145,7 @@ def _add_child(self, env: Environment, tunables: TunableGroups) -> None: self._tunable_params.merge(env.tunable_params) tunables.merge(env.tunable_params) - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Set up the children environments. @@ -190,9 +165,7 @@ def setup( """ assert self._in_context self._is_ready = super().setup(tunables, global_config) and all( - env_context.setup(tunables, global_config) - for env_context in self._child_contexts - ) + env_context.setup(tunables, global_config) for env_context in self._child_contexts) return self._is_ready def teardown(self) -> None: @@ -229,9 +202,7 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: for env_context in self._child_contexts: _LOG.debug("Child env. run: %s", env_context) (status, timestamp, metrics) = env_context.run() - _LOG.debug( - "Child env. run results: %s :: %s %s", env_context, status, metrics - ) + _LOG.debug("Child env. 
run results: %s :: %s %s", env_context, status, metrics) if not status.is_good(): _LOG.info("Run failed: %s :: %s", self, status) return (status, timestamp, None) diff --git a/mlos_bench/mlos_bench/environments/local/__init__.py b/mlos_bench/mlos_bench/environments/local/__init__.py index a99eefea19..0cdd8349b4 100644 --- a/mlos_bench/mlos_bench/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/environments/local/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv __all__ = [ - "LocalEnv", - "LocalFileShareEnv", + 'LocalEnv', + 'LocalFileShareEnv', ] diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index a78898d90b..da20f5c961 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -36,15 +36,13 @@ class LocalEnv(ScriptEnv): Scheduler-side Environment that runs scripts locally. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for local execution. @@ -67,17 +65,11 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) - - assert self._service is not None and isinstance( - self._service, SupportsLocalExec - ), "LocalEnv requires a service that supports local execution" + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ + "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service self._temp_dir: Optional[str] = None @@ -87,24 +79,17 @@ def __init__( self._dump_meta_file: Optional[str] = self.config.get("dump_meta_file") self._read_results_file: Optional[str] = self.config.get("read_results_file") - self._read_telemetry_file: Optional[str] = self.config.get( - "read_telemetry_file" - ) + self._read_telemetry_file: Optional[str] = self.config.get("read_telemetry_file") def __enter__(self) -> Environment: assert self._temp_dir is None and self._temp_dir_context is None - self._temp_dir_context = self._local_exec_service.temp_dir_context( - self.config.get("temp_dir") - ) + self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir")) self._temp_dir = self._temp_dir_context.__enter__() return super().__enter__() - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the benchmarking environment. 
""" @@ -114,9 +99,7 @@ def __exit__( self._temp_dir_context = None return super().__exit__(ex_type, ex_val, ex_tb) - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if the environment is ready and set up the application and benchmarks, if necessary. @@ -154,19 +137,13 @@ def setup( fname = path_join(self._temp_dir, self._dump_meta_file) _LOG.debug("Dump tunables metadata to file: %s", fname) with open(fname, "wt", encoding="utf-8") as fh_meta: - json.dump( - { - tunable.name: tunable.meta - for (tunable, _group) in self._tunable_params - if tunable.meta - }, - fh_meta, - ) + json.dump({ + tunable.name: tunable.meta + for (tunable, _group) in self._tunable_params if tunable.meta + }, fh_meta) if self._script_setup: - (return_code, _output) = self._local_exec( - self._script_setup, self._temp_dir - ) + (return_code, _output) = self._local_exec(self._script_setup, self._temp_dir) self._is_ready = bool(return_code == 0) else: self._is_ready = True @@ -203,26 +180,18 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: _LOG.debug("Not reading the data at: %s", self) return (Status.SUCCEEDED, timestamp, stdout_data) - data = self._normalize_columns( - pandas.read_csv( - self._config_loader_service.resolve_path( - self._read_results_file, extra_paths=[self._temp_dir] - ), - index_col=False, - ) - ) + data = self._normalize_columns(pandas.read_csv( + self._config_loader_service.resolve_path( + self._read_results_file, extra_paths=[self._temp_dir]), + index_col=False, + )) _LOG.debug("Read data:\n%s", data) if list(data.columns) == ["metric", "value"]: - _LOG.info( - "Local results have (metric,value) header and %d rows: assume long format", - len(data), - ) - data = pandas.DataFrame( - [data.value.to_list()], columns=data.metric.to_list() - ) + _LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data)) + data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list()) # Try to convert string metrics to numbers. - data = data.apply(pandas.to_numeric, errors="coerce").fillna(data) # type: ignore[assignment] # (false positive) + data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive) elif len(data) == 1: _LOG.info("Local results have 1 row: assume wide format") else: @@ -240,8 +209,8 @@ def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame: # Windows cmd interpretation of > redirect symbols can leave trailing spaces in # the final column, which leads to misnamed columns. # For now, we simply strip trailing spaces from column names to account for that. - if sys.platform == "win32": - data.rename(str.rstrip, axis="columns", inplace=True) + if sys.platform == 'win32': + data.rename(str.rstrip, axis='columns', inplace=True) return data def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: @@ -253,45 +222,36 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: assert self._temp_dir is not None try: fname = self._config_loader_service.resolve_path( - self._read_telemetry_file, extra_paths=[self._temp_dir] - ) + self._read_telemetry_file, extra_paths=[self._temp_dir]) # TODO: Use the timestamp of the CSV file as our status timestamp? # FIXME: We should not be assuming that the only output file type is a CSV. 
- data = self._normalize_columns(pandas.read_csv(fname, index_col=False)) + data = self._normalize_columns( + pandas.read_csv(fname, index_col=False)) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") expected_col_names = ["timestamp", "metric", "value"] if len(data.columns) != len(expected_col_names): - raise ValueError( - f"Telemetry data must have columns {expected_col_names}" - ) + raise ValueError(f'Telemetry data must have columns {expected_col_names}') if list(data.columns) != expected_col_names: # Assume no header - this is ok for telemetry data. - data = pandas.read_csv(fname, index_col=False, names=expected_col_names) + data = pandas.read_csv( + fname, index_col=False, names=expected_col_names) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") except FileNotFoundError as ex: - _LOG.warning( - "Telemetry CSV file not found: %s :: %s", self._read_telemetry_file, ex - ) + _LOG.warning("Telemetry CSV file not found: %s :: %s", self._read_telemetry_file, ex) return (status, timestamp, []) _LOG.debug("Read telemetry data:\n%s", data) col_dtypes: Mapping[int, Type] = {0: datetime} - return ( - status, - timestamp, - [ - (pandas.Timestamp(ts).to_pydatetime(), metric, value) - for (ts, metric, value) in data.to_records( - index=False, column_dtypes=col_dtypes - ) - ], - ) + return (status, timestamp, [ + (pandas.Timestamp(ts).to_pydatetime(), metric, value) + for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) + ]) def teardown(self) -> None: """ @@ -303,9 +263,7 @@ def teardown(self) -> None: _LOG.info("Local teardown complete: %s :: %s", self, return_code) super().teardown() - def _local_exec( - self, script: Iterable[str], cwd: Optional[str] = None - ) -> Tuple[int, dict]: + def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple[int, dict]: """ Execute a script locally in the scheduler environment. @@ -325,10 +283,7 @@ def _local_exec( env_params = self._get_env_params() _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params) (return_code, stdout, stderr) = self._local_exec_service.local_exec( - script, env=env_params, cwd=cwd - ) + script, env=env_params, cwd=cwd) if return_code != 0: - _LOG.warning( - "ERROR: Local script returns code %d stderr:\n%s", return_code, stderr - ) + _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr) return (return_code, {"stdout": stdout, "stderr": stderr}) diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index 636c7cb9a5..174afd387c 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -29,15 +29,13 @@ class LocalFileShareEnv(LocalEnv): and uploads/downloads data to the shared file storage. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new application environment with a given config. @@ -61,22 +59,14 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - assert self._service is not None and isinstance( - self._service, SupportsLocalExec - ), "LocalEnv requires a service that supports local execution" + assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ + "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service - assert self._service is not None and isinstance( - self._service, SupportsFileShareOps - ), "LocalEnv requires a service that supports file upload/download operations" + assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \ + "LocalEnv requires a service that supports file upload/download operations" self._file_share_service: SupportsFileShareOps = self._service self._upload = self._template_from_to("upload") @@ -88,14 +78,13 @@ def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]: of string.Template objects so that we can plug in self._params into it later. """ return [ - (Template(d["from"]), Template(d["to"])) + (Template(d['from']), Template(d['to'])) for d in self.config.get(config_key, []) ] @staticmethod - def _expand( - from_to: Iterable[Tuple[Template, Template]], params: Mapping[str, TunableValue] - ) -> Generator[Tuple[str, str], None, None]: + def _expand(from_to: Iterable[Tuple[Template, Template]], + params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. Return a generator of (str, str) pairs of paths. @@ -105,9 +94,7 @@ def _expand( for (path_from, path_to) in from_to ) - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Run setup scripts locally and upload the scripts and data to the shared storage. 
@@ -132,14 +119,9 @@ def setup( assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for path_from, path_to in self._expand(self._upload, params): - self._file_share_service.upload( - self._params, - self._config_loader_service.resolve_path( - path_from, extra_paths=[self._temp_dir] - ), - path_to, - ) + for (path_from, path_to) in self._expand(self._upload, params): + self._file_share_service.upload(self._params, self._config_loader_service.resolve_path( + path_from, extra_paths=[self._temp_dir]), path_to) return self._is_ready def _download_files(self, ignore_missing: bool = False) -> None: @@ -155,15 +137,11 @@ def _download_files(self, ignore_missing: bool = False) -> None: assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for path_from, path_to in self._expand(self._download, params): + for (path_from, path_to) in self._expand(self._download, params): try: - self._file_share_service.download( - self._params, - path_from, - self._config_loader_service.resolve_path( - path_to, extra_paths=[self._temp_dir] - ), - ) + self._file_share_service.download(self._params, + path_from, self._config_loader_service.resolve_path( + path_to, extra_paths=[self._temp_dir])) except FileNotFoundError as ex: _LOG.warning("Cannot download: %s", path_from) if not ignore_missing: diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index c9d6ac7ed3..cc47b95500 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -29,15 +29,13 @@ class MockEnv(Environment): _NOISE_VAR = 0.2 """Variance of the Gaussian noise added to the benchmark value.""" - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment that produces mock benchmark data. @@ -57,13 +55,8 @@ def __init__( service: Service An optional service object. Not used by this class. """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) seed = int(self.config.get("mock_env_seed", -1)) self._random = random.Random(seed or None) if seed >= 0 else None self._range = self.config.get("mock_env_range") @@ -88,14 +81,9 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: return result # Simple convex function of all tunable parameters. - score = numpy.mean( - numpy.square( - [ - self._normalized(tunable) - for (tunable, _group) in self._tunable_params - ] - ) - ) + score = numpy.mean(numpy.square([ + self._normalized(tunable) for (tunable, _group) in self._tunable_params + ])) # Add noise and shift the benchmark value from [0, 1] to a given range. 
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0 @@ -103,11 +91,7 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: if self._range: score = self._range[0] + score * (self._range[1] - self._range[0]) - return ( - Status.SUCCEEDED, - timestamp, - {metric: score for metric in self._metrics}, - ) + return (Status.SUCCEEDED, timestamp, {metric: score for metric in self._metrics}) @staticmethod def _normalized(tunable: Tunable) -> float: @@ -117,13 +101,11 @@ def _normalized(tunable: Tunable) -> float: """ val = None if tunable.is_categorical: - val = tunable.categories.index(tunable.category) / float( - len(tunable.categories) - 1 - ) + val = (tunable.categories.index(tunable.category) / + float(len(tunable.categories) - 1)) elif tunable.is_numerical: - val = (tunable.numerical_value - tunable.range[0]) / float( - tunable.range[1] - tunable.range[0] - ) + val = ((tunable.numerical_value - tunable.range[0]) / + float(tunable.range[1] - tunable.range[0])) else: raise ValueError("Invalid parameter type: " + tunable.type) # Explicitly clip the value in case of numerical errors. diff --git a/mlos_bench/mlos_bench/environments/remote/__init__.py b/mlos_bench/mlos_bench/environments/remote/__init__.py index be18bff2fe..f07575ac86 100644 --- a/mlos_bench/mlos_bench/environments/remote/__init__.py +++ b/mlos_bench/mlos_bench/environments/remote/__init__.py @@ -14,10 +14,10 @@ from mlos_bench.environments.remote.vm_env import VMEnv __all__ = [ - "HostEnv", - "NetworkEnv", - "OSEnv", - "RemoteEnv", - "SaaSEnv", - "VMEnv", + 'HostEnv', + 'NetworkEnv', + 'OSEnv', + 'RemoteEnv', + 'SaaSEnv', + 'VMEnv', ] diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py index e754fce417..05896c9e60 100644 --- a/mlos_bench/mlos_bench/environments/remote/host_env.py +++ b/mlos_bench/mlos_bench/environments/remote/host_env.py @@ -22,15 +22,13 @@ class HostEnv(Environment): Remote host environment. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for host operations. @@ -51,22 +49,13 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM/host, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - assert self._service is not None and isinstance( - self._service, SupportsHostProvisioning - ), "HostEnv requires a service that supports host provisioning operations" + assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \ + "HostEnv requires a service that supports host provisioning operations" self._host_service: SupportsHostProvisioning = self._service - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if host is ready. (Re)provision and start it, if necessary. 
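For reference (not part of the patch): both layouts of MockEnv.run() further up compute the same benchmark value, the mean of the squared normalized tunable values plus optional Gaussian noise, rescaled into mock_env_range when one is given. A simplified arithmetic check with made-up inputs (Status handling and clipping omitted):

    import random

    import numpy

    normalized = [0.0, 0.5, 1.0]   # hypothetical outputs of _normalized() for three tunables
    rng = random.Random(42)        # seeded, as with a non-negative mock_env_seed
    noise_var = 0.2                # mirrors the _NOISE_VAR class constant

    score = numpy.mean(numpy.square(normalized))   # (0 + 0.25 + 1) / 3 ~= 0.417
    score = score + rng.gauss(0, noise_var)        # add Gaussian noise
    mock_env_range = (60.0, 120.0)                 # hypothetical range
    score = mock_env_range[0] + score * (mock_env_range[1] - mock_env_range[0])
    print(round(float(score), 3))
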
@@ -104,9 +93,7 @@ def teardown(self) -> None: _LOG.info("Host tear down: %s", self) (status, params) = self._host_service.deprovision_host(self._params) if status.is_pending(): - (status, _) = self._host_service.wait_host_deployment( - params, is_setup=False - ) + (status, _) = self._host_service.wait_host_deployment(params, is_setup=False) super().teardown() _LOG.debug("Final status of Host deprovisioning: %s :: %s", self, status) diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index ba06e7ad5c..552f1729d9 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -27,15 +27,13 @@ class NetworkEnv(Environment): but no real tuning is expected for it ... yet. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for network operations. @@ -56,26 +54,17 @@ def __init__( An optional service object (e.g., providing methods to deploy a network, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) # Virtual networks can be used for more than one experiment, so by default # we don't attempt to deprovision them. self._deprovision_on_teardown = config.get("deprovision_on_teardown", False) - assert self._service is not None and isinstance( - self._service, SupportsNetworkProvisioning - ), "NetworkEnv requires a service that supports network provisioning" + assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \ + "NetworkEnv requires a service that supports network provisioning" self._network_service: SupportsNetworkProvisioning = self._service - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if network is ready. Provision, if necessary. 
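For reference (not part of the patch): the HostEnv and NetworkEnv hunks around here reshape, but do not change, the same two-step flow: start a (de)provisioning operation, then block on the corresponding wait_*_deployment call while the service reports a pending status. A toy, self-contained sketch of that control flow with stand-in Status and service classes:

    from enum import Enum


    class Status(Enum):
        # Minimal stand-in for mlos_bench.environments.status.Status.
        PENDING = 1
        SUCCEEDED = 2
        FAILED = 3

        def is_pending(self) -> bool:
            return self is Status.PENDING

        def is_succeeded(self) -> bool:
            return self is Status.SUCCEEDED


    class FakeProvisioner:
        # Toy stand-in for a SupportsHostProvisioning / SupportsNetworkProvisioning service.
        def provision(self, params: dict):
            return (Status.PENDING, {"deployment": "dep-1", **params})

        def wait_deployment(self, params: dict, *, is_setup: bool):
            return (Status.SUCCEEDED, params)


    def provision_and_wait(service: FakeProvisioner, params: dict) -> bool:
        (status, poll_params) = service.provision(params)
        if status.is_pending():
            (status, _) = service.wait_deployment(poll_params, is_setup=True)
        return status.is_succeeded()


    print(provision_and_wait(FakeProvisioner(), {"vmName": "test-vm"}))   # True
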
@@ -102,9 +91,7 @@ def setup( (status, params) = self._network_service.provision_network(self._params) if status.is_pending(): - (status, _) = self._network_service.wait_network_deployment( - params, is_setup=True - ) + (status, _) = self._network_service.wait_network_deployment(params, is_setup=True) self._is_ready = status.is_succeeded() return self._is_ready @@ -118,13 +105,9 @@ def teardown(self) -> None: return # Else _LOG.info("Network tear down: %s", self) - (status, params) = self._network_service.deprovision_network( - self._params, ignore_errors=True - ) + (status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True) if status.is_pending(): - (status, _) = self._network_service.wait_network_deployment( - params, is_setup=False - ) + (status, _) = self._network_service.wait_network_deployment(params, is_setup=False) super().teardown() _LOG.debug("Final status of Network deprovisioning: %s :: %s", self, status) diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py index 398c3b65db..ef733c77c2 100644 --- a/mlos_bench/mlos_bench/environments/remote/os_env.py +++ b/mlos_bench/mlos_bench/environments/remote/os_env.py @@ -24,15 +24,13 @@ class OSEnv(Environment): OS Level Environment for a host. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for remote execution. @@ -55,27 +53,17 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) - - assert self._service is not None and isinstance( - self._service, SupportsHostOps - ), "RemoteEnv requires a service that supports host operations" + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsHostOps), \ + "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance( - self._service, SupportsOSOps - ), "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance(self._service, SupportsOSOps), \ + "RemoteEnv requires a service that supports host operations" self._os_service: SupportsOSOps = self._service - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if the host is up and running; boot it, if necessary. diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index 112d83c4f1..cf38a57b01 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -32,15 +32,13 @@ class RemoteEnv(ScriptEnv): e.g. 
Application Environment """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for remote execution. @@ -63,30 +61,21 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a Host, VM, OS, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) self._wait_boot = self.config.get("wait_boot", False) - assert self._service is not None and isinstance( - self._service, SupportsRemoteExec - ), "RemoteEnv requires a service that supports remote execution operations" + assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \ + "RemoteEnv requires a service that supports remote execution operations" self._remote_exec_service: SupportsRemoteExec = self._service if self._wait_boot: - assert self._service is not None and isinstance( - self._service, SupportsHostOps - ), "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance(self._service, SupportsHostOps), \ + "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Check if the environment is ready and set up the application and benchmarks on a remote host. @@ -163,9 +152,7 @@ def teardown(self) -> None: _LOG.info("Remote teardown complete: %s :: %s", self, status) super().teardown() - def _remote_exec( - self, script: Iterable[str] - ) -> Tuple[Status, datetime, Optional[dict]]: + def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optional[dict]]: """ Run a script on the remote host. @@ -183,8 +170,7 @@ def _remote_exec( env_params = self._get_env_params() _LOG.debug("Submit script: %s with %s", self, env_params) (status, output) = self._remote_exec_service.remote_exec( - script, config=self._params, env_params=env_params - ) + script, config=self._params, env_params=env_params) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) if status in {Status.PENDING, Status.SUCCEEDED}: (status, output) = self._remote_exec_service.get_remote_exec_results(output) diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index 8885bafc05..b661bfad7e 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -23,15 +23,13 @@ class SaaSEnv(Environment): Cloud-based (configurable) SaaS environment. 
""" - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for (configurable) cloud-based SaaS instance. @@ -52,27 +50,18 @@ def __init__( An optional service object (e.g., providing methods to configure the remote service). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) - - assert self._service is not None and isinstance( - self._service, SupportsHostOps - ), "RemoteEnv requires a service that supports host operations" + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsHostOps), \ + "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance( - self._service, SupportsRemoteConfig - ), "SaaSEnv requires a service that supports remote host configuration API" + assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \ + "SaaSEnv requires a service that supports remote host configuration API" self._config_service: SupportsRemoteConfig = self._service - def setup( - self, tunables: TunableGroups, global_config: Optional[dict] = None - ) -> bool: + def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Update the configuration of a remote SaaS instance. @@ -95,8 +84,7 @@ def setup( return False (status, _) = self._config_service.configure( - self._params, self._tunable_params.get_param_values() - ) + self._params, self._tunable_params.get_param_values()) if not status.is_succeeded(): return False @@ -105,7 +93,7 @@ def setup( return False # Azure Flex DB instances currently require a VM reboot after reconfiguration. - if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"): + if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'): _LOG.info("Restarting: %s", self) (status, params) = self._host_service.restart_host(self._params) if status.is_pending(): diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index d2e4992700..129ac21a0f 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -27,15 +27,13 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta): _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]") - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for script execution. @@ -65,29 +63,19 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) self._script_setup = self.config.get("setup") self._script_run = self.config.get("run") self._script_teardown = self.config.get("teardown") self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", []) - self._shell_env_params_rename: Dict[str, str] = self.config.get( - "shell_env_params_rename", {} - ) + self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {}) results_stdout_pattern = self.config.get("results_stdout_pattern") - self._results_stdout_pattern: Optional[re.Pattern[str]] = ( - re.compile(results_stdout_pattern, flags=re.MULTILINE) - if results_stdout_pattern - else None - ) + self._results_stdout_pattern: Optional[re.Pattern[str]] = \ + re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None def _get_env_params(self, restrict: bool = True) -> Dict[str, str]: """ @@ -127,10 +115,5 @@ def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]: """ if not self._results_stdout_pattern: return {} - _LOG.debug( - "Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout - ) - return { - key: try_parse_val(val) - for (key, val) in self._results_stdout_pattern.findall(stdout) - } + _LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout) + return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)} diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index f39f96c5ad..4555ab7f50 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -20,7 +20,7 @@ else: from typing_extensions import TypeAlias -CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name +CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name if sys.version_info >= (3, 9): FutureReturnType: TypeAlias = Future[CoroReturnType] else: @@ -66,14 +66,10 @@ def enter(self) -> None: assert self._event_loop_thread_refcnt == 0 if self._event_loop is None: if sys.platform == "win32": - asyncio.set_event_loop_policy( - asyncio.WindowsSelectorEventLoopPolicy() - ) + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) self._event_loop = asyncio.new_event_loop() assert not self._event_loop.is_running() - self._event_loop_thread = Thread( - target=self._run_event_loop, daemon=True - ) + self._event_loop_thread = Thread(target=self._run_event_loop, daemon=True) self._event_loop_thread.start() self._event_loop_thread_refcnt += 1 @@ -94,9 +90,7 @@ def exit(self) -> None: raise RuntimeError("Failed to stop event loop thread.") self._event_loop_thread = None - def run_coroutine( - self, coro: Coroutine[Any, Any, CoroReturnType] - ) -> FutureReturnType: + def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: """ Runs the given coroutine in the background event loop thread and returns a Future that can be used to wait for the result. 
diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index c20ef557d0..c8e48dab69 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -32,9 +32,7 @@ from mlos_bench.util import try_parse_val _LOG_LEVEL = logging.INFO -_LOG_FORMAT = ( - "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s" -) +_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s' logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT) _LOG = logging.getLogger(__name__) @@ -46,9 +44,7 @@ class Launcher: Command line launcher for mlos_bench and mlos_core. """ - def __init__( - self, description: str, long_text: str = "", argv: Optional[List[str]] = None - ): + def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None): # pylint: disable=too-many-statements _LOG.info("Launch: %s", description) epilog = """ @@ -58,9 +54,8 @@ def __init__( For additional details, please see the website or the README.md files in the source tree: """ - parser = argparse.ArgumentParser( - description=f"{description} : {long_text}", epilog=epilog - ) + parser = argparse.ArgumentParser(description=f"{description} : {long_text}", + epilog=epilog) (args, args_rest) = self._parse_args(parser, argv) # Bootstrap config loader: command line takes priority. @@ -101,50 +96,38 @@ def __init__( # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI. # It's useful to keep it there explicitly mostly for the --help output. if args.experiment_id: - self.global_config["experiment_id"] = args.experiment_id + self.global_config['experiment_id'] = args.experiment_id # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line if args.trial_config_repeat_count: - self.global_config["trial_config_repeat_count"] = ( - args.trial_config_repeat_count - ) + self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count # Ensure that the trial_id is present since it gets used by some other # configs but is typically controlled by the run optimize loop. - self.global_config.setdefault("trial_id", 1) + self.global_config.setdefault('trial_id', 1) - self.global_config = DictTemplater(self.global_config).expand_vars( - use_os_env=True - ) + self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True) assert isinstance(self.global_config, dict) # --service cli args should override the config file values. service_files: List[str] = config.get("services", []) + (args.service or []) assert isinstance(self._parent_service, SupportsConfigLoading) - self._parent_service = self._parent_service.load_services( - service_files, self.global_config, self._parent_service - ) + self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service) env_path = args.environment or config.get("environment") if not env_path: _LOG.error("No environment config specified.") - parser.error( - "At least the Environment config must be specified." - + " Run `mlos_bench --help` and consult `README.md` for more info." - ) + parser.error("At least the Environment config must be specified." 
+ + " Run `mlos_bench --help` and consult `README.md` for more info.") self.root_env_config = self._config_loader.resolve_path(env_path) self.environment: Environment = self._config_loader.load_environment( - self.root_env_config, - TunableGroups(), - self.global_config, - service=self._parent_service, - ) + self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service) _LOG.info("Init environment: %s", self.environment) # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer self.tunables = self._init_tunable_values( args.random_init or config.get("random_init", False), config.get("random_seed") if args.random_seed is None else args.random_seed, - config.get("tunable_values", []) + (args.tunable_values or []), + config.get("tunable_values", []) + (args.tunable_values or []) ) _LOG.info("Init tunables: %s", self.tunables) @@ -154,11 +137,7 @@ def __init__( self.storage = self._load_storage(args.storage or config.get("storage")) _LOG.info("Init storage: %s", self.storage) - self.teardown: bool = ( - bool(args.teardown) - if args.teardown is not None - else bool(config.get("teardown", True)) - ) + self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True)) self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler")) _LOG.info("Init scheduler: %s", self.scheduler) @@ -177,146 +156,87 @@ def service(self) -> Service: return self._parent_service @staticmethod - def _parse_args( - parser: argparse.ArgumentParser, argv: Optional[List[str]] - ) -> Tuple[argparse.Namespace, List[str]]: + def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]: """ Parse the command line arguments. """ parser.add_argument( - "--config", - required=False, - help="Main JSON5 configuration file. Its keys are the same as the" - + " command line options and can be overridden by the latter.\n" - + "\n" - + " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ " - + " for additional config examples for this and other arguments.", - ) + '--config', required=False, + help='Main JSON5 configuration file. Its keys are the same as the' + + ' command line options and can be overridden by the latter.\n' + + '\n' + + ' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' + + ' for additional config examples for this and other arguments.') parser.add_argument( - "--log_file", - "--log-file", - required=False, - help="Path to the log file. Use stdout if omitted.", - ) + '--log_file', '--log-file', required=False, + help='Path to the log file. Use stdout if omitted.') parser.add_argument( - "--log_level", - "--log-level", - required=False, - type=str, - help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}." - + " Set to DEBUG for debug, WARNING for warnings only.", - ) + '--log_level', '--log-level', required=False, type=str, + help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' 
+ + ' Set to DEBUG for debug, WARNING for warnings only.') parser.add_argument( - "--config_path", - "--config-path", - "--config-paths", - "--config_paths", - nargs="+", - action="extend", - required=False, - help="One or more locations of JSON config files.", - ) + '--config_path', '--config-path', '--config-paths', '--config_paths', + nargs="+", action='extend', required=False, + help='One or more locations of JSON config files.') parser.add_argument( - "--service", - "--services", - nargs="+", - action="extend", - required=False, - help="Path to JSON file with the configuration of the service(s) for environment(s) to use.", - ) + '--service', '--services', + nargs='+', action='extend', required=False, + help='Path to JSON file with the configuration of the service(s) for environment(s) to use.') parser.add_argument( - "--environment", - required=False, - help="Path to JSON file with the configuration of the benchmarking environment(s).", - ) + '--environment', required=False, + help='Path to JSON file with the configuration of the benchmarking environment(s).') parser.add_argument( - "--optimizer", - required=False, - help="Path to the optimizer configuration file. If omitted, run" - + " a single trial with default (or specified in --tunable_values).", - ) + '--optimizer', required=False, + help='Path to the optimizer configuration file. If omitted, run' + + ' a single trial with default (or specified in --tunable_values).') parser.add_argument( - "--trial_config_repeat_count", - "--trial-config-repeat-count", - required=False, - type=int, - help="Number of times to repeat each config. Default is 1 trial per config, though more may be advised.", - ) + '--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int, + help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.') parser.add_argument( - "--scheduler", - required=False, - help="Path to the scheduler configuration file. By default, use" - + " a single worker synchronous scheduler.", - ) + '--scheduler', required=False, + help='Path to the scheduler configuration file. By default, use' + + ' a single worker synchronous scheduler.') parser.add_argument( - "--storage", - required=False, - help="Path to the storage configuration file." - + " If omitted, use the ephemeral in-memory SQL storage.", - ) + '--storage', required=False, + help='Path to the storage configuration file.' + + ' If omitted, use the ephemeral in-memory SQL storage.') parser.add_argument( - "--random_init", - "--random-init", - required=False, - default=False, - dest="random_init", - action="store_true", - help="Initialize tunables with random values. (Before applying --tunable_values).", - ) + '--random_init', '--random-init', required=False, default=False, + dest='random_init', action='store_true', + help='Initialize tunables with random values. (Before applying --tunable_values).') parser.add_argument( - "--random_seed", - "--random-seed", - required=False, - type=int, - help="Seed to use with --random_init", - ) + '--random_seed', '--random-seed', required=False, type=int, + help='Seed to use with --random_init') parser.add_argument( - "--tunable_values", - "--tunable-values", - nargs="+", - action="extend", - required=False, - help="Path to one or more JSON files that contain values of the tunable" - + " parameters. 
This can be used for a single trial (when no --optimizer" - + " is specified) or as default values for the first run in optimization.", - ) + '--tunable_values', '--tunable-values', nargs="+", action='extend', required=False, + help='Path to one or more JSON files that contain values of the tunable' + + ' parameters. This can be used for a single trial (when no --optimizer' + + ' is specified) or as default values for the first run in optimization.') parser.add_argument( - "--globals", - nargs="+", - action="extend", - required=False, - help="Path to one or more JSON files that contain additional" - + " [private] parameters of the benchmarking environment.", - ) + '--globals', nargs="+", action='extend', required=False, + help='Path to one or more JSON files that contain additional' + + ' [private] parameters of the benchmarking environment.') parser.add_argument( - "--no_teardown", - "--no-teardown", - required=False, - default=None, - dest="teardown", - action="store_false", - help="Disable teardown of the environment after the benchmark.", - ) + '--no_teardown', '--no-teardown', required=False, default=None, + dest='teardown', action='store_false', + help='Disable teardown of the environment after the benchmark.') parser.add_argument( - "--experiment_id", - "--experiment-id", - required=False, - default=None, + '--experiment_id', '--experiment-id', required=False, default=None, help=""" Experiment ID to use for the benchmark. If omitted, the value from the --cli config or --globals is used. @@ -326,7 +246,7 @@ def _parse_args( changes are made to config files, scripts, versions, etc. This is left as a manual operation as detection of what is "incompatible" is not easily automatable across systems. - """, + """ ) # By default we use the command line arguments, but allow the caller to @@ -368,18 +288,16 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]: _LOG.debug("Parsed config: %s", config) return config - def _load_config( - self, - args_globals: Iterable[str], - config_path: Iterable[str], - args_rest: Iterable[str], - global_config: Dict[str, Any], - ) -> Dict[str, Any]: + def _load_config(self, + args_globals: Iterable[str], + config_path: Iterable[str], + args_rest: Iterable[str], + global_config: Dict[str, Any]) -> Dict[str, Any]: """ Get key/value pairs of the global configuration parameters from the specified config files (if any) and command line arguments. """ - for config_file in args_globals or []: + for config_file in (args_globals or []): conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS) assert isinstance(conf, dict) global_config.update(conf) @@ -388,9 +306,8 @@ def _load_config( global_config["config_path"] = config_path return global_config - def _init_tunable_values( - self, random_init: bool, seed: Optional[int], args_tunables: Optional[str] - ) -> TunableGroups: + def _init_tunable_values(self, random_init: bool, seed: Optional[int], + args_tunables: Optional[str]) -> TunableGroups: """ Initialize the tunables and load key/value pairs of the tunable values from given JSON files, if specified. 
@@ -400,17 +317,13 @@ def _init_tunable_values( if random_init: tunables = MockOptimizer( - tunables=tunables, - service=None, - config={"start_with_defaults": False, "seed": seed}, - ).suggest() + tunables=tunables, service=None, + config={"start_with_defaults": False, "seed": seed}).suggest() _LOG.debug("Init tunables: random = %s", tunables) if args_tunables is not None: for data_file in args_tunables: - values = self._config_loader.load_config( - data_file, ConfigSchema.TUNABLE_VALUES - ) + values = self._config_loader.load_config(data_file, ConfigSchema.TUNABLE_VALUES) assert isinstance(values, Dict) tunables.assign(values) _LOG.debug("Init tunables: load %s = %s", data_file, tunables) @@ -426,24 +339,15 @@ def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer: if args_optimizer is None: # global_config may contain additional properties, so we need to # strip those out before instantiating the basic oneshot optimizer. - config = { - key: val - for key, val in self.global_config.items() - if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS - } + config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS} return OneShotOptimizer( - self.tunables, config=config, service=self._parent_service - ) - class_config = self._config_loader.load_config( - args_optimizer, ConfigSchema.OPTIMIZER - ) + self.tunables, config=config, service=self._parent_service) + class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER) assert isinstance(class_config, Dict) - optimizer = self._config_loader.build_optimizer( - tunables=self.tunables, - service=self._parent_service, - config=class_config, - global_config=self.global_config, - ) + optimizer = self._config_loader.build_optimizer(tunables=self.tunables, + service=self._parent_service, + config=class_config, + global_config=self.global_config) return optimizer def _load_storage(self, args_storage: Optional[str]) -> Storage: @@ -455,24 +359,17 @@ def _load_storage(self, args_storage: Optional[str]) -> Storage: if args_storage is None: # pylint: disable=import-outside-toplevel from mlos_bench.storage.sql.storage import SqlStorage - - return SqlStorage( - service=self._parent_service, - config={ - "drivername": "sqlite", - "database": ":memory:", - "lazy_schema_create": True, - }, - ) - class_config = self._config_loader.load_config( - args_storage, ConfigSchema.STORAGE - ) + return SqlStorage(service=self._parent_service, + config={ + "drivername": "sqlite", + "database": ":memory:", + "lazy_schema_create": True, + }) + class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE) assert isinstance(class_config, Dict) - storage = self._config_loader.build_storage( - service=self._parent_service, - config=class_config, - global_config=self.global_config, - ) + storage = self._config_loader.build_storage(service=self._parent_service, + config=class_config, + global_config=self.global_config) return storage def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: @@ -487,7 +384,6 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: if args_scheduler is None: # pylint: disable=import-outside-toplevel from mlos_bench.schedulers.sync_scheduler import SyncScheduler - return SyncScheduler( # All config values can be overridden from global config config={ @@ -503,9 +399,7 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: storage=self.storage, root_env_config=self.root_env_config, ) - 
class_config = self._config_loader.load_config( - args_scheduler, ConfigSchema.SCHEDULER - ) + class_config = self._config_loader.load_config(args_scheduler, ConfigSchema.SCHEDULER) assert isinstance(class_config, Dict) return self._config_loader.build_scheduler( config=class_config, diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py index a61b55d440..f10fa3c82e 100644 --- a/mlos_bench/mlos_bench/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/optimizers/__init__.py @@ -12,8 +12,8 @@ from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer __all__ = [ - "Optimizer", - "MockOptimizer", - "OneShotOptimizer", - "MlosCoreOptimizer", + 'Optimizer', + 'MockOptimizer', + 'OneShotOptimizer', + 'MlosCoreOptimizer', ] diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index b67ebbfbd9..b9df1db1b7 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -26,7 +26,7 @@ _LOG = logging.getLogger(__name__) -class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes +class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ An abstract interface between the benchmarking framework and mlos_core optimizers. """ @@ -39,13 +39,11 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attribu "start_with_defaults", } - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): """ Create a new optimizer for the given configuration space defined by the tunables. @@ -69,30 +67,25 @@ def __init__( self._seed = int(config.get("seed", 42)) self._in_context = False - experiment_id = self._global_config.get("experiment_id") + experiment_id = self._global_config.get('experiment_id') self.experiment_id = str(experiment_id).strip() if experiment_id else None self._iter = 0 # If False, use the optimizer to suggest the initial configuration; # if True (default), use the already initialized values for the first iteration. self._start_with_defaults: bool = bool( - strtobool(str(self._config.pop("start_with_defaults", True))) - ) - self._max_iter = int(self._config.pop("max_suggestions", 100)) + strtobool(str(self._config.pop('start_with_defaults', True)))) + self._max_iter = int(self._config.pop('max_suggestions', 100)) - opt_targets: Dict[str, str] = self._config.pop( - "optimization_targets", {"score": "min"} - ) + opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'}) self._opt_targets: Dict[str, Literal[1, -1]] = {} - for opt_target, opt_dir in opt_targets.items(): + for (opt_target, opt_dir) in opt_targets.items(): if opt_dir == "min": self._opt_targets[opt_target] = 1 elif opt_dir == "max": self._opt_targets[opt_target] = -1 else: - raise ValueError( - f"Invalid optimization direction: {opt_dir} for {opt_target}" - ) + raise ValueError(f"Invalid optimization direction: {opt_dir} for {opt_target}") def _validate_json_config(self, config: dict) -> None: """ @@ -114,7 +107,7 @@ def __repr__(self) -> str: ) return f"{self.name}({opt_targets},config={self._config})" - def __enter__(self) -> "Optimizer": + def __enter__(self) -> 'Optimizer': """ Enter the optimizer's context. 
""" @@ -123,12 +116,9 @@ def __enter__(self) -> "Optimizer": self._in_context = True return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the optimizer. """ @@ -200,9 +190,7 @@ def config_space(self) -> ConfigurationSpace: The ConfigSpace representation of the tunable parameters. """ if self._config_space is None: - self._config_space = tunable_groups_to_configspace( - self._tunables, self._seed - ) + self._config_space = tunable_groups_to_configspace(self._tunables, self._seed) _LOG.debug("ConfigSpace: %s", self._config_space) return self._config_space @@ -215,7 +203,7 @@ def name(self) -> str: return self.__class__.__name__ @property - def targets(self) -> Dict[str, Literal["min", "max"]]: + def targets(self) -> Dict[str, Literal['min', 'max']]: """ A dictionary of {target: direction} of optimization targets. """ @@ -232,12 +220,10 @@ def supports_preload(self) -> bool: return True @abstractmethod - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: """ Pre-load the optimizer with the bulk data from previous experiments. @@ -255,12 +241,8 @@ def bulk_register( is_not_empty : bool True if there is data to register, false otherwise. """ - _LOG.info( - "Update the optimizer with: %d configs, %d scores, %d status values", - len(configs or []), - len(scores or []), - len(status or []), - ) + _LOG.info("Update the optimizer with: %d configs, %d scores, %d status values", + len(configs or []), len(scores or []), len(status or [])) if len(configs or []) != len(scores or []): raise ValueError("Numbers of configs and scores do not match.") if status is not None and len(configs or []) != len(status or []): @@ -289,12 +271,8 @@ def suggest(self) -> TunableGroups: return self._tunables.copy() @abstractmethod - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: """ Register the observation for the given configuration. @@ -315,22 +293,15 @@ def register( Benchmark scores extracted (and possibly transformed) from the dataframe that's being MINIMIZED. """ - _LOG.info( - "Iteration %d :: Register: %s = %s score: %s", - self._iter, - tunables, - status, - score, - ) + _LOG.info("Iteration %d :: Register: %s = %s score: %s", + self._iter, tunables, status, score) if status.is_succeeded() == (score is None): # XOR raise ValueError("Status and score must be consistent.") return self._get_scores(status, score) - def _get_scores( - self, - status: Status, - scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]], - ) -> Optional[Dict[str, float]]: + def _get_scores(self, status: Status, + scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] + ) -> Optional[Dict[str, float]]: """ Extract a scalar benchmark score from the dataframe. Change the sign if we are maximizing. 
@@ -359,7 +330,7 @@ def _get_scores( assert scores is not None target_metrics: Dict[str, float] = {} - for opt_target, opt_dir in self._opt_targets.items(): + for (opt_target, opt_dir) in self._opt_targets.items(): val = scores[opt_target] assert val is not None target_metrics[opt_target] = float(val) * opt_dir @@ -374,9 +345,7 @@ def not_converged(self) -> bool: return self._iter < self._max_iter @abstractmethod - def get_best_observation( - self, - ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: """ Get the best observation so far. diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index 6dc24c01d9..62341c613d 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -48,8 +48,7 @@ def _normalize_weights(weights: List[float]) -> List[float]: def _tunable_to_configspace( - tunable: Tunable, group_name: Optional[str] = None, cost: int = 0 -) -> ConfigurationSpace: + tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace: """ Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects, wrapped in a ConfigurationSpace for composability. @@ -72,19 +71,14 @@ def _tunable_to_configspace( meta = {"group": group_name, "cost": cost} # {"scaling": ""} if tunable.type == "categorical": - return ConfigurationSpace( - { - tunable.name: CategoricalHyperparameter( - name=tunable.name, - choices=tunable.categories, - weights=( - _normalize_weights(tunable.weights) if tunable.weights else None - ), - default_value=tunable.default, - meta=meta, - ) - } - ) + return ConfigurationSpace({ + tunable.name: CategoricalHyperparameter( + name=tunable.name, + choices=tunable.categories, + weights=_normalize_weights(tunable.weights) if tunable.weights else None, + default_value=tunable.default, + meta=meta) + }) distribution: Union[Uniform, Normal, Beta, None] = None if tunable.distribution == "uniform": @@ -92,12 +86,12 @@ def _tunable_to_configspace( elif tunable.distribution == "normal": distribution = Normal( mu=tunable.distribution_params["mu"], - sigma=tunable.distribution_params["sigma"], + sigma=tunable.distribution_params["sigma"] ) elif tunable.distribution == "beta": distribution = Beta( alpha=tunable.distribution_params["alpha"], - beta=tunable.distribution_params["beta"], + beta=tunable.distribution_params["beta"] ) elif tunable.distribution is not None: raise TypeError(f"Invalid Distribution Type: {tunable.distribution}") @@ -109,26 +103,22 @@ def _tunable_to_configspace( log=bool(tunable.is_log), q=nullable(int, tunable.quantization), distribution=distribution, - default=( - int(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None - ), - meta=meta, + default=(int(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None), + meta=meta ) elif tunable.type == "float": range_hp = Float( name=tunable.name, bounds=tunable.range, log=bool(tunable.is_log), - q=tunable.quantization, # type: ignore[arg-type] + q=tunable.quantization, # type: ignore[arg-type] distribution=distribution, # type: ignore[arg-type] - default=( - float(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None - ), - meta=meta, + default=(float(tunable.default) + if 
tunable.in_range(tunable.default) and tunable.default is not None + else None), + meta=meta ) else: raise TypeError(f"Invalid Parameter Type: {tunable.type}") @@ -141,50 +131,36 @@ def _tunable_to_configspace( switch_weights = [0.5, 0.5] # FLAML requires uniform weights. if tunable.weights and tunable.range_weight is not None: special_weights = _normalize_weights(tunable.weights) - switch_weights = _normalize_weights( - [sum(tunable.weights), tunable.range_weight] - ) + switch_weights = _normalize_weights([sum(tunable.weights), tunable.range_weight]) # Create three hyperparameters: one for regular values, # one for special values, and one to choose between the two. (special_name, type_name) = special_param_names(tunable.name) - conf_space = ConfigurationSpace( - { - tunable.name: range_hp, - special_name: CategoricalHyperparameter( - name=special_name, - choices=tunable.special, - weights=special_weights, - default_value=( - tunable.default if tunable.default in tunable.special else None - ), - meta=meta, - ), - type_name: CategoricalHyperparameter( - name=type_name, - choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], - weights=switch_weights, - default_value=TunableValueKind.SPECIAL, - ), - } - ) - conf_space.add_condition( - EqualsCondition( - conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL - ) - ) - conf_space.add_condition( - EqualsCondition( - conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE - ) - ) + conf_space = ConfigurationSpace({ + tunable.name: range_hp, + special_name: CategoricalHyperparameter( + name=special_name, + choices=tunable.special, + weights=special_weights, + default_value=tunable.default if tunable.default in tunable.special else None, + meta=meta + ), + type_name: CategoricalHyperparameter( + name=type_name, + choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], + weights=switch_weights, + default_value=TunableValueKind.SPECIAL, + ), + }) + conf_space.add_condition(EqualsCondition( + conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)) + conf_space.add_condition(EqualsCondition( + conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)) return conf_space -def tunable_groups_to_configspace( - tunables: TunableGroups, seed: Optional[int] = None -) -> ConfigurationSpace: +def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace: """ Convert TunableGroups to hyperparameters in ConfigurationSpace. @@ -202,14 +178,11 @@ def tunable_groups_to_configspace( A new ConfigurationSpace instance that corresponds to the input TunableGroups. """ space = ConfigurationSpace(seed=seed) - for tunable, group in tunables: + for (tunable, group) in tunables: space.add_configuration_space( - prefix="", - delimiter="", + prefix="", delimiter="", configuration_space=_tunable_to_configspace( - tunable, group.name, group.get_current_cost() - ), - ) + tunable, group.name, group.get_current_cost())) return space @@ -228,7 +201,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration: A ConfigSpace Configuration. 
""" values: Dict[str, TunableValue] = {} - for tunable, _group in tunables: + for (tunable, _group) in tunables: if tunable.special: (special_name, type_name) = special_param_names(tunable.name) if tunable.value in tunable.special: @@ -251,8 +224,7 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]: data = data.copy() specials = [ special_param_name_strip(k) - for k in data.keys() - if special_param_name_is_temp(k) + for k in data.keys() if special_param_name_is_temp(k) ] for k in specials: (special_name, type_name) = special_param_names(k) diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 4f5efb6aa7..4f207f5fc9 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -28,13 +28,11 @@ class GridSearchOptimizer(TrackBestOptimizer): Grid search optimizer. """ - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) # Track the grid as a set of tuples of tunable values and reconstruct the @@ -53,21 +51,11 @@ def __init__( def _sanity_check(self) -> None: size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables]) if size == np.inf: - raise ValueError( - f"Unquantized tunables are not supported for grid search: {self._tunables}" - ) + raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}") if size > 10000: - _LOG.warning( - "Large number %d of config points requested for grid search: %s", - size, - self._tunables, - ) + _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables) if size > self._max_iter: - _LOG.warning( - "Grid search size %d, is greater than max iterations %d", - size, - self._max_iter, - ) + _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter) def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]: """ @@ -80,14 +68,12 @@ def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], Non # names instead of the order given by TunableGroups. configs = [ configspace_data_to_tunable_values(dict(config)) - for config in generate_grid( - self.config_space, - { - tunable.name: int(tunable.cardinality) - for (tunable, _group) in self._tunables - if tunable.quantization or tunable.type == "int" - }, - ) + for config in + generate_grid(self.config_space, { + tunable.name: int(tunable.cardinality) + for (tunable, _group) in self._tunables + if tunable.quantization or tunable.type == "int" + }) ] names = set(tuple(configs.keys()) for configs in configs) assert len(names) == 1 @@ -103,10 +89,7 @@ def pending_configs(self) -> Iterable[Dict[str, TunableValue]]: Iterable[Dict[str, TunableValue]] """ # See NOTEs above. - return ( - dict(zip(self._config_keys, config)) - for config in self._pending_configs.keys() - ) + return (dict(zip(self._config_keys, config)) for config in self._pending_configs.keys()) @property def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]: @@ -118,21 +101,17 @@ def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]: Iterable[Dict[str, TunableValue]] """ # See NOTEs above. 
- return ( - dict(zip(self._config_keys, config)) for config in self._suggested_configs - ) - - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + return (dict(zip(self._config_keys, config)) for config in self._suggested_configs) + + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for params, score, trial_status in zip(configs, scores, status): + for (params, score, trial_status) in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -173,34 +152,20 @@ def suggest(self) -> TunableGroups: _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) try: - config = dict( - ConfigSpace.Configuration( - self.config_space, values=tunables.get_param_values() - ) - ) + config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())) self._suggested_configs.remove(tuple(config.values())) except KeyError: - _LOG.warning( - "Attempted to remove missing config (previously registered?) from suggested set: %s", - tunables, - ) + _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables) return registered_score def not_converged(self) -> bool: if self._iter > self._max_iter: if bool(self._pending_configs): - _LOG.warning( - "Exceeded max iterations, but still have %d pending configs: %s", - len(self._pending_configs), - list(self._pending_configs.keys()), - ) + _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s", + len(self._pending_configs), list(self._pending_configs.keys())) return False return bool(self._pending_configs) diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index c30134d1b1..d7d50f1ca5 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -40,42 +40,35 @@ class MlosCoreOptimizer(Optimizer): A wrapper class for the mlos_core optimizers. 
""" - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) - opt_type = getattr( - OptimizerType, - self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name), - ) + opt_type = getattr(OptimizerType, self._config.pop( + 'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name)) if opt_type == OptimizerType.SMAC: - output_directory = self._config.get("output_directory") + output_directory = self._config.get('output_directory') if output_directory is not None: # If output_directory is specified, turn it into an absolute path. - self._config["output_directory"] = os.path.abspath(output_directory) + self._config['output_directory'] = os.path.abspath(output_directory) else: - _LOG.warning( - "SMAC optimizer output_directory was null. SMAC will use a temporary directory." - ) + _LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.") # Make sure max_trials >= max_iterations. - if "max_trials" not in self._config: - self._config["max_trials"] = self._max_iter - assert ( - int(self._config["max_trials"]) >= self._max_iter - ), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" + if 'max_trials' not in self._config: + self._config['max_trials'] = self._max_iter + assert int(self._config['max_trials']) >= self._max_iter, \ + f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" - if "run_name" not in self._config and self.experiment_id: - self._config["run_name"] = self.experiment_id + if 'run_name' not in self._config and self.experiment_id: + self._config['run_name'] = self.experiment_id - space_adapter_type = self._config.pop("space_adapter_type", None) - space_adapter_config = self._config.pop("space_adapter_config", {}) + space_adapter_type = self._config.pop('space_adapter_type', None) + space_adapter_config = self._config.pop('space_adapter_config', {}) if space_adapter_type is not None: space_adapter_type = getattr(SpaceAdapterType, space_adapter_type) @@ -89,12 +82,9 @@ def __init__( space_adapter_kwargs=space_adapter_config, ) - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: self._opt.cleanup() return super().__exit__(ex_type, ex_val, ex_tb) @@ -102,12 +92,10 @@ def __exit__( def name(self) -> str: return f"{self.__class__.__name__}:{self._opt.__class__.__name__}" - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False @@ -115,8 +103,7 @@ def bulk_register( df_configs = self._to_df(configs) # Impute missing values, if necessary df_scores = self._adjust_signs_df( - pd.DataFrame([{} if score is None else score for score in scores]) - ) + pd.DataFrame([{} if score is None else score for score in scores])) opt_targets = 
list(self._opt_targets) if status is not None: @@ -131,9 +118,7 @@ def bulk_register( # TODO: Specify (in the config) which metrics to pass to the optimizer. # Issue: https://github.com/microsoft/MLOS/issues/745 - self._opt.register( - configs=df_configs, scores=df_scores[opt_targets].astype(float) - ) + self._opt.register(configs=df_configs, scores=df_scores[opt_targets].astype(float)) if _LOG.isEnabledFor(logging.DEBUG): (score, _) = self.get_best_observation() @@ -145,7 +130,7 @@ def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame: """ In-place adjust the signs of the scores for MINIMIZATION problem. """ - for opt_target, opt_dir in self._opt_targets.items(): + for (opt_target, opt_dir) in self._opt_targets.items(): df_scores[opt_target] *= opt_dir return df_scores @@ -167,7 +152,7 @@ def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: df_configs = pd.DataFrame(configs) tunables_names = list(self._tunables.get_param_values().keys()) missing_cols = set(tunables_names).difference(df_configs.columns) - for tunable, _group in self._tunables: + for (tunable, _group) in self._tunables: if tunable.name in missing_cols: df_configs[tunable.name] = tunable.default else: @@ -178,9 +163,7 @@ def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: if tunable.special: (special_name, type_name) = special_param_names(tunable.name) tunables_names += [special_name, type_name] - is_special = df_configs[tunable.name].apply( - tunable.special.__contains__ - ) + is_special = df_configs[tunable.name].apply(tunable.special.__contains__) df_configs[type_name] = TunableValueKind.RANGE df_configs.loc[is_special, type_name] = TunableValueKind.SPECIAL if tunable.type == "int": @@ -202,32 +185,21 @@ def suggest(self) -> TunableGroups: self._start_with_defaults = False _LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config) return tunables.assign( - configspace_data_to_tunable_values(df_config.loc[0].to_dict()) - ) + configspace_data_to_tunable_values(df_config.loc[0].to_dict())) - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: - registered_score = super().register( - tunables, status, score - ) # Sign-adjusted for MINIMIZATION + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION if status.is_completed(): assert registered_score is not None df_config = self._to_df([tunables.get_param_values()]) _LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config) # TODO: Specify (in the config) which metrics to pass to the optimizer. 
# Issue: https://github.com/microsoft/MLOS/issues/745 - self._opt.register( - configs=df_config, scores=pd.DataFrame([registered_score], dtype=float) - ) + self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float)) return registered_score - def get_best_observation( - self, - ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: (df_config, df_score, _df_context) = self._opt.get_best_observations() if len(df_config) == 0: return (None, None) diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index 8dd13eb182..ada4411b58 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -24,13 +24,11 @@ class MockOptimizer(TrackBestOptimizer): Mock optimizer to test the Environment API. """ - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) rnd = random.Random(self.seed) self._random: Dict[str, Callable[[Tunable], TunableValue]] = { @@ -39,17 +37,15 @@ def __init__( "int": lambda tunable: rnd.randint(*tunable.range), } - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for params, score, trial_status in zip(configs, scores, status): + for (params, score, trial_status) in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -66,7 +62,7 @@ def suggest(self) -> TunableGroups: _LOG.info("Use default tunable values") self._start_with_defaults = False else: - for tunable, _group in tunables: + for (tunable, _group) in tunables: tunable.value = self._random[tunable.type](tunable) _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index b7a14f8af2..9ad1070c46 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -24,13 +24,11 @@ class OneShotOptimizer(MockOptimizer): # TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344 - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) _LOG.info("Run a single iteration for: %s", self._tunables) self._max_iter = 1 # Always run for just one iteration. 
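Note (illustration only, not part of the patch): every hunk in this patch toggles between the same two argument-layout conventions, so the diffs are easier to scan once the pattern is clear. The "-" lines carry black-style output (double quotes, one argument per line, a trailing comma, lines kept within black's default 88 characters), while the "+" lines carry the older hanging-indent, single-quoted style. The sketch below is a minimal, self-contained comparison of the two layouts using a hypothetical `load_config` helper; it is not an mlos_bench API and the body is illustrative only.

    from typing import Any, Dict, Optional


    def load_config(path: str,
                    schema: Optional[str] = None,
                    defaults: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Hanging-indent, single-continuation style, as on the '+' lines of these hunks."""
        # Merge the defaults with the call-site arguments (placeholder logic).
        return dict(defaults or {}, path=path, schema=schema)


    def load_config_black(
        path: str,
        schema: Optional[str] = None,
        defaults: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """The same signature laid out one parameter per line, as on the '-' lines."""
        # Identical behavior; only the formatting of the signature differs.
        return dict(defaults or {}, path=path, schema=schema)

The trailing comma in the second form is significant to black (its "magic trailing comma"): once present, black keeps the parameters exploded one per line rather than re-joining them, which is why the exploded layout appears consistently throughout these hunks.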
diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py index 0fd54b2dfa..32a23142e3 100644 --- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py @@ -24,23 +24,17 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta): Base Optimizer class that keeps track of the best score and configuration. """ - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) self._best_config: Optional[TunableGroups] = None self._best_score: Optional[Dict[str, float]] = None - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) if status.is_succeeded() and self._is_better(registered_score): self._best_score = registered_score @@ -54,7 +48,7 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: if self._best_score is None: return True assert registered_score is not None - for opt_target, best_score in self._best_score.items(): + for (opt_target, best_score) in self._best_score.items(): score = registered_score[opt_target] if score < best_score: return True @@ -62,9 +56,7 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: return False return False - def get_best_observation( - self, - ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: if self._best_score is None: return (None, None) score = self._get_scores(Status.SUCCEEDED, self._best_score) diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index 7f26851c6b..a7912688a1 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py @@ -22,19 +22,16 @@ from typing_extensions import TypeAlias if sys.version_info >= (3, 9): - EnvironType: TypeAlias = os._Environ[ - str - ] # pylint: disable=protected-access,disable=unsubscriptable-object + EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object else: - EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access + EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access # Handle case sensitivity differences between platforms. 
# https://stackoverflow.com/a/19023293 -if sys.platform == "win32": +if sys.platform == 'win32': import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8) - environ: EnvironType = nt.environ else: environ: EnvironType = os.environ -__all__ = ["environ"] +__all__ = ['environ'] diff --git a/mlos_bench/mlos_bench/run.py b/mlos_bench/mlos_bench/run.py index 3dc5cbbfd4..85c8c2b0c5 100755 --- a/mlos_bench/mlos_bench/run.py +++ b/mlos_bench/mlos_bench/run.py @@ -20,13 +20,10 @@ _LOG = logging.getLogger(__name__) -def _main( - argv: Optional[List[str]] = None, -) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: +def _main(argv: Optional[List[str]] = None + ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: - launcher = Launcher( - "mlos_bench", "Systems autotuning and benchmarking tool", argv=argv - ) + launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv) with launcher.scheduler as scheduler_context: scheduler_context.start() diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py index c53d11231d..c54e3c0efc 100644 --- a/mlos_bench/mlos_bench/schedulers/__init__.py +++ b/mlos_bench/mlos_bench/schedulers/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.schedulers.sync_scheduler import SyncScheduler __all__ = [ - "Scheduler", - "SyncScheduler", + 'Scheduler', + 'SyncScheduler', ] diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index c089ff5946..0b6733e423 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -31,16 +31,13 @@ class Scheduler(metaclass=ABCMeta): Base class for the optimization loop scheduling policies. """ - def __init__( - self, - *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: Storage, - root_env_config: str, - ): + def __init__(self, *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: Storage, + root_env_config: str): """ Create a new instance of the scheduler. The constructor of this and the derived classes is called by the persistence service @@ -63,11 +60,8 @@ def __init__( Path to the root environment configuration. """ self.global_config = global_config - config = merge_parameters( - dest=config.copy(), - source=global_config, - required_keys=["experiment_id", "trial_id"], - ) + config = merge_parameters(dest=config.copy(), source=global_config, + required_keys=["experiment_id", "trial_id"]) self._experiment_id = config["experiment_id"].strip() self._trial_id = int(config["trial_id"]) @@ -75,13 +69,9 @@ def __init__( self._max_trials = int(config.get("max_trials", -1)) self._trial_count = 0 - self._trial_config_repeat_count = int( - config.get("trial_config_repeat_count", 1) - ) + self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1)) if self._trial_config_repeat_count <= 0: - raise ValueError( - f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}" - ) + raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}") self._do_teardown = bool(config.get("teardown", True)) @@ -105,7 +95,7 @@ def __repr__(self) -> str: """ return self.__class__.__name__ - def __enter__(self) -> "Scheduler": + def __enter__(self) -> 'Scheduler': """ Enter the scheduler's context. 
""" @@ -127,12 +117,10 @@ def __enter__(self) -> "Scheduler": ).__enter__() return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the scheduler. """ @@ -154,12 +142,8 @@ def start(self) -> None: Start the optimization loop. """ assert self.experiment is not None - _LOG.info( - "START: Experiment: %s Env: %s Optimizer: %s", - self.experiment, - self.environment, - self.optimizer, - ) + _LOG.info("START: Experiment: %s Env: %s Optimizer: %s", + self.experiment, self.environment, self.optimizer) if _LOG.isEnabledFor(logging.INFO): _LOG.info("Root Environment:\n%s", self.environment.pprint()) @@ -176,9 +160,7 @@ def teardown(self) -> None: if self._do_teardown: self.environment.teardown() - def get_best_observation( - self, - ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: + def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: """ Get the best observation from the optimizer. """ @@ -195,9 +177,7 @@ def load_config(self, config_id: int) -> TunableGroups: tunables = self.environment.tunable_params.assign(tunable_values) _LOG.info("Load config from storage: %d", config_id) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2) - ) + _LOG.debug("Config %d ::\n%s", config_id, json.dumps(tunable_values, indent=2)) return tunables def _schedule_new_optimizer_suggestions(self) -> bool: @@ -224,33 +204,27 @@ def schedule_trial(self, tunables: TunableGroups) -> None: Add a configuration to the queue of trials. """ for repeat_i in range(1, self._trial_config_repeat_count + 1): - self._add_trial_to_queue( - tunables, - config={ - # Add some additional metadata to track for the trial such as the - # optimizer config used. - # Note: these values are unfortunately mutable at the moment. - # Consider them as hints of what the config was the trial *started*. - # It is possible that the experiment configs were changed - # between resuming the experiment (since that is not currently - # prevented). - "optimizer": self.optimizer.name, - "repeat_i": repeat_i, - "is_defaults": tunables.is_defaults, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(self.optimizer.targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - }, - }, - ) + self._add_trial_to_queue(tunables, config={ + # Add some additional metadata to track for the trial such as the + # optimizer config used. + # Note: these values are unfortunately mutable at the moment. + # Consider them as hints of what the config was the trial *started*. + # It is possible that the experiment configs were changed + # between resuming the experiment (since that is not currently + # prevented). 
+ "optimizer": self.optimizer.name, + "repeat_i": repeat_i, + "is_defaults": tunables.is_defaults, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(self.optimizer.targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + } + }) - def _add_trial_to_queue( - self, - tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None, - ) -> None: + def _add_trial_to_queue(self, tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None) -> None: """ Add a configuration to the queue of trials. A wrapper for the `Experiment.new_trial` method. @@ -283,9 +257,4 @@ def run_trial(self, trial: Storage.Trial) -> None: """ assert self.experiment is not None self._trial_count += 1 - _LOG.info( - "QUEUE: Execute trial # %d/%d :: %s", - self._trial_count, - self._max_trials, - trial, - ) + _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial) diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index 3e196d4d4f..a73a493533 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -53,9 +53,7 @@ def run_trial(self, trial: Storage.Trial) -> None: trial.update(Status.FAILED, datetime.now(UTC)) return - (status, timestamp, results) = ( - self.environment.run() - ) # Block and wait for the final result. + (status, timestamp, results) = self.environment.run() # Block and wait for the final result. _LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results) # In async mode (TODO), poll the environment for status and telemetry diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py index dacbb88126..bcc7d02d6f 100644 --- a/mlos_bench/mlos_bench/services/__init__.py +++ b/mlos_bench/mlos_bench/services/__init__.py @@ -11,7 +11,7 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - "Service", - "FileShareService", - "LocalExecService", + 'Service', + 'FileShareService', + 'LocalExecService', ] diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index 63c222ee45..f00a7a1a00 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -21,13 +21,10 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta): An abstract base of all file shares. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new file share with a given config. @@ -45,16 +42,12 @@ def __init__( New methods to register with the service. 
""" super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.upload, self.download]), + config, global_config, parent, + self.merge_methods(methods, [self.upload, self.download]) ) @abstractmethod - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: """ Downloads contents from a remote share path to a local path. @@ -72,18 +65,11 @@ def download( if True (the default), download the entire directory tree. """ params = params or {} - _LOG.info( - "Download from File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", - remote_path, - local_path, - params, - ) + _LOG.info("Download from File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", remote_path, local_path, params) @abstractmethod - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: """ Uploads contents from a local path to remote share path. @@ -100,10 +86,5 @@ def upload( if True (the default), upload the entire directory tree. """ params = params or {} - _LOG.info( - "Upload to File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", - local_path, - remote_path, - params, - ) + _LOG.info("Upload to File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", local_path, remote_path, params) diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index 724fd6e8f2..e7c9365bf7 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -26,13 +26,11 @@ class Service: """ @classmethod - def new( - cls, - class_name: str, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - ) -> "Service": + def new(cls, + class_name: str, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None) -> "Service": """ Factory method for a new service with a given config. @@ -59,13 +57,11 @@ def new( assert issubclass(cls, Service) return instantiate_from_config(cls, class_name, config, global_config, parent) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new service with a given config. 
@@ -100,23 +96,13 @@ def __init__( self._config_loader_service = parent if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Service: %s Config:\n%s", self, json.dumps(self.config, indent=2) - ) - _LOG.debug( - "Service: %s Globals:\n%s", - self, - json.dumps(global_config or {}, indent=2), - ) - _LOG.debug( - "Service: %s Parent: %s", self, parent.pprint() if parent else None - ) + _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2)) + _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2)) + _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None) @staticmethod - def merge_methods( - ext_methods: Union[Dict[str, Callable], List[Callable], None], - local_methods: Union[Dict[str, Callable], List[Callable]], - ) -> Dict[str, Callable]: + def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None], + local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]: """ Merge methods from the external caller with the local ones. This function is usually called by the derived class constructor @@ -152,12 +138,9 @@ def __enter__(self) -> "Service": self._in_context = True return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the Service mix-in context. @@ -194,12 +177,9 @@ def _enter_context(self) -> "Service": self._in_context = True return self - def _exit_context( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def _exit_context(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exits the context for this particular Service instance. @@ -285,12 +265,10 @@ def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None # Unfortunately, by creating a set, we may destroy the ability to # preserve the context enter/exit order, but hopefully it doesn't # matter. - svc_method.__self__ - for _, svc_method in self._service_methods.items() + svc_method.__self__ for _, svc_method in self._service_methods.items() # Note: some methods are actually stand alone functions, so we need # to filter them out. - if hasattr(svc_method, "__self__") - and isinstance(svc_method.__self__, Service) + if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service) } def export(self) -> Dict[str, Callable]: diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index 85cc849b0e..cac3216d61 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -59,17 +59,13 @@ class ConfigPersistenceService(Service, SupportsConfigLoading): Collection of methods to deserialize the Environment, Service, and TunableGroups objects. 
""" - BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace( - "\\", "/" - ) - - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/") + + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of config persistence service. @@ -86,22 +82,17 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.resolve_path, - self.load_config, - self.prepare_class_load, - self.build_service, - self.build_environment, - self.load_services, - self.load_environment, - self.load_environment_list, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + self.resolve_path, + self.load_config, + self.prepare_class_load, + self.build_service, + self.build_environment, + self.load_services, + self.load_environment, + self.load_environment_list, + ]) ) self._config_loader_service = self @@ -129,9 +120,8 @@ def config_paths(self) -> List[str]: """ return list(self._config_path) # make a copy to avoid modifications - def resolve_path( - self, file_path: str, extra_paths: Optional[Iterable[str]] = None - ) -> str: + def resolve_path(self, file_path: str, + extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -161,11 +151,10 @@ def resolve_path( _LOG.debug("Path not resolved: %s", file_path) return file_path - def load_config( - self, - json_file_name: str, - schema_type: Optional[ConfigSchema], - ) -> Dict[str, Any]: + def load_config(self, + json_file_name: str, + schema_type: Optional[ConfigSchema], + ) -> Dict[str, Any]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. 
@@ -185,22 +174,16 @@ def load_config( """ json_file_name = self.resolve_path(json_file_name) _LOG.info("Load config: %s", json_file_name) - with open(json_file_name, mode="r", encoding="utf-8") as fh_json: + with open(json_file_name, mode='r', encoding='utf-8') as fh_json: config = json5.load(fh_json) if schema_type is not None: try: schema_type.validate(config) except (ValidationError, SchemaError) as ex: - _LOG.error( - "Failed to validate config %s against schema type %s at %s", - json_file_name, - schema_type.name, - schema_type.value, - ) - raise ValueError( - f"Failed to validate config {json_file_name} against " - + f"schema type {schema_type.name} at {schema_type.value}" - ) from ex + _LOG.error("Failed to validate config %s against schema type %s at %s", + json_file_name, schema_type.name, schema_type.value) + raise ValueError(f"Failed to validate config {json_file_name} against " + + f"schema type {schema_type.name} at {schema_type.value}") from ex if isinstance(config, dict) and config.get("$schema"): # Remove $schema attributes from the config after we've validated # them to avoid passing them on to other objects @@ -211,14 +194,11 @@ def load_config( del config["$schema"] else: _LOG.warning("Config %s is not validated against a schema.", json_file_name) - return config # type: ignore[no-any-return] + return config # type: ignore[no-any-return] - def prepare_class_load( - self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - ) -> Tuple[str, Dict[str, Any]]: + def prepare_class_load(self, config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. Mix-in the global parameters and resolve the local file system paths, @@ -252,35 +232,25 @@ def prepare_class_load( merge_parameters(dest=class_config, source=global_config) - for key in set(class_config).intersection( - config.get("resolve_config_property_paths", []) - ): + for key in set(class_config).intersection(config.get("resolve_config_property_paths", [])): if isinstance(class_config[key], str): class_config[key] = self.resolve_path(class_config[key]) elif isinstance(class_config[key], (list, tuple)): - class_config[key] = [ - self.resolve_path(path) for path in class_config[key] - ] + class_config[key] = [self.resolve_path(path) for path in class_config[key]] else: raise ValueError(f"Parameter {key} must be a string or a list") if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Instantiating: %s with config:\n%s", - class_name, - json.dumps(class_config, indent=2), - ) + _LOG.debug("Instantiating: %s with config:\n%s", + class_name, json.dumps(class_config, indent=2)) return (class_name, class_config) - def build_optimizer( - self, - *, - tunables: TunableGroups, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - ) -> Optimizer: + def build_optimizer(self, *, + tunables: TunableGroups, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None) -> Optimizer: """ Instantiation of mlos_bench Optimizer that depend on Service and TunableGroups. 
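Aside from the line wrapping, the `load_config` hunk above boils down to: parse the file as JSON5, validate it against the requested schema, then strip the editor-only `$schema` key so it is not passed on to downstream constructors. A trimmed sketch of that flow, with the schema validation elided and assuming the same `json5` package the service already imports:

    import json5  # third-party JSON5 parser, as used by load_config above

    def load_and_strip_schema_sketch(path: str) -> dict:
        with open(path, mode="r", encoding="utf-8") as fh:
            config = json5.load(fh)
        # Validation against a ConfigSchema would happen here.
        if isinstance(config, dict):
            config.pop("$schema", None)  # drop the hint once validation is done
        return config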
@@ -309,24 +279,18 @@ def build_optimizer( if tunables_path is not None: tunables = self._load_tunables(tunables_path, tunables) (class_name, class_config) = self.prepare_class_load(config, global_config) - inst = instantiate_from_config( - Optimizer, - class_name, # type: ignore[type-abstract] - tunables=tunables, - config=class_config, - global_config=global_config, - service=service, - ) + inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract] + tunables=tunables, + config=class_config, + global_config=global_config, + service=service) _LOG.info("Created: Optimizer %s", inst) return inst - def build_storage( - self, - *, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - ) -> "Storage": + def build_storage(self, *, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None) -> "Storage": """ Instantiation of mlos_bench Storage objects. @@ -348,27 +312,20 @@ def build_storage( from mlos_bench.storage.base_storage import ( Storage, # pylint: disable=import-outside-toplevel ) - - inst = instantiate_from_config( - Storage, - class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - service=service, - ) + inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + service=service) _LOG.info("Created: Storage %s", inst) return inst - def build_scheduler( - self, - *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: "Storage", - root_env_config: str, - ) -> "Scheduler": + def build_scheduler(self, *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: "Storage", + root_env_config: str) -> "Scheduler": """ Instantiation of mlos_bench Scheduler. @@ -396,28 +353,22 @@ def build_scheduler( from mlos_bench.schedulers.base_scheduler import ( Scheduler, # pylint: disable=import-outside-toplevel ) - - inst = instantiate_from_config( - Scheduler, - class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - environment=environment, - optimizer=optimizer, - storage=storage, - root_env_config=root_env_config, - ) + inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + environment=environment, + optimizer=optimizer, + storage=storage, + root_env_config=root_env_config) _LOG.info("Created: Scheduler %s", inst) return inst - def build_environment( - self, # pylint: disable=too-many-arguments - config: Dict[str, Any], - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None, - ) -> Environment: + def build_environment(self, # pylint: disable=too-many-arguments + config: Dict[str, Any], + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None) -> Environment: """ Factory method for a new environment with a given config. @@ -446,9 +397,7 @@ def build_environment( An instance of the `Environment` class initialized with `config`. 
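`build_optimizer`, `build_storage`, and `build_scheduler` all funnel a dotted class name plus keyword arguments into `instantiate_from_config`. That helper's body is outside this diff, so purely as an illustration of the general dynamic-import pattern (without the abstract-base-class check implied by the `# type: ignore[type-abstract]` comments), a toy version could look like:

    import importlib
    from typing import Any

    def instantiate_by_name_sketch(class_name: str, **kwargs: Any) -> Any:
        # "pkg.module.ClassName" -> import pkg.module, look up ClassName, construct it.
        (module_name, _, attr_name) = class_name.rpartition(".")
        cls = getattr(importlib.import_module(module_name), attr_name)
        return cls(**kwargs)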
""" env_name = config["name"] - (env_class, env_config) = self.prepare_class_load( - config, global_config, parent_args - ) + (env_class, env_config) = self.prepare_class_load(config, global_config, parent_args) env_services_path = config.get("include_services") if env_services_path is not None: @@ -459,24 +408,16 @@ def build_environment( tunables = self._load_tunables(env_tunables_path, tunables) _LOG.debug("Creating env: %s :: %s", env_name, env_class) - env = Environment.new( - env_name=env_name, - class_name=env_class, - config=env_config, - global_config=global_config, - tunables=tunables, - service=service, - ) + env = Environment.new(env_name=env_name, class_name=env_class, + config=env_config, global_config=global_config, + tunables=tunables, service=service) _LOG.info("Created env: %s :: %s", env_name, env) return env - def _build_standalone_service( - self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def _build_standalone_service(self, config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -501,12 +442,9 @@ def _build_standalone_service( _LOG.info("Created service: %s", service) return service - def _build_composite_service( - self, - config_list: Iterable[Dict[str, Any]], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -532,21 +470,18 @@ def _build_composite_service( service.register(parent.export()) for config in config_list: - service.register( - self._build_standalone_service(config, global_config, service).export() - ) + service.register(self._build_standalone_service( + config, global_config, service).export()) if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Created mix-in service: %s", service) return service - def build_service( - self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def build_service(self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -568,7 +503,8 @@ def build_service( services from the list plus the parent mix-in. 
""" if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2)) + _LOG.debug("Build service from config:\n%s", + json.dumps(config, indent=2)) assert isinstance(config, dict) config_list: List[Dict[str, Any]] @@ -583,14 +519,12 @@ def build_service( return self._build_composite_service(config_list, global_config, parent) - def load_environment( - self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None, - ) -> Environment: + def load_environment(self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None) -> Environment: """ Load and build new environment from the config file. @@ -615,18 +549,14 @@ def load_environment( """ config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT) assert isinstance(config, dict) - return self.build_environment( - config, tunables, global_config, parent_args, service - ) + return self.build_environment(config, tunables, global_config, parent_args, service) - def load_environment_list( - self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None, - ) -> List[Environment]: + def load_environment_list(self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None) -> List[Environment]: """ Load and build a list of environments from the config file. @@ -652,17 +582,12 @@ def load_environment_list( """ config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT) return [ - self.build_environment( - config, tunables, global_config, parent_args, service - ) + self.build_environment(config, tunables, global_config, parent_args, service) ] - def load_services( - self, - json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def load_services(self, json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Read the configuration files and bundle all service methods from those configs into a single Service object. @@ -681,20 +606,16 @@ def load_services( service : Service A collection of service methods. 
""" - _LOG.info( - "Load services: %s parent: %s", json_file_names, parent.__class__.__name__ - ) + _LOG.info("Load services: %s parent: %s", + json_file_names, parent.__class__.__name__) service = Service({}, global_config, parent) for fname in json_file_names: config = self.load_config(fname, ConfigSchema.SERVICE) - service.register( - self.build_service(config, global_config, service).export() - ) + service.register(self.build_service(config, global_config, service).export()) return service - def _load_tunables( - self, json_file_names: Iterable[str], parent: TunableGroups - ) -> TunableGroups: + def _load_tunables(self, json_file_names: Iterable[str], + parent: TunableGroups) -> TunableGroups: """ Load a collection of tunable parameters from JSON files into the parent TunableGroup. diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py index b9d0c267c1..abb87c8b52 100644 --- a/mlos_bench/mlos_bench/services/local/__init__.py +++ b/mlos_bench/mlos_bench/services/local/__init__.py @@ -9,5 +9,5 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - "LocalExecService", + 'LocalExecService', ] diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 6b9bca1a0c..47534be7b1 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -79,13 +79,11 @@ class LocalExecService(TempDirContextService, SupportsLocalExec): due to reduced dependency management complications vs the target environment. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of a service to run scripts locally. @@ -102,19 +100,14 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.local_exec]), + config, global_config, parent, + self.merge_methods(methods, [self.local_exec]) ) self.abort_on_error = self.config.get("abort_on_error", True) - def local_exec( - self, - script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None, - ) -> Tuple[int, str, str]: + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. 
@@ -140,9 +133,7 @@ def local_exec( _LOG.debug("Run in directory: %s", temp_dir) for line in script_lines: - (return_code, stdout, stderr) = self._local_exec_script( - line, env, temp_dir - ) + (return_code, stdout, stderr) = self._local_exec_script(line, env, temp_dir) stdout_list.append(stdout) stderr_list.append(stderr) if return_code != 0 and self.abort_on_error: @@ -184,12 +175,9 @@ def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: subcmd_tokens.insert(0, sys.executable) return subcmd_tokens - def _local_exec_script( - self, - script_line: str, - env_params: Optional[Mapping[str, "TunableValue"]], - cwd: str, - ) -> Tuple[int, str, str]: + def _local_exec_script(self, script_line: str, + env_params: Optional[Mapping[str, "TunableValue"]], + cwd: str) -> Tuple[int, str, str]: """ Execute the script from `script_path` in a local process. @@ -218,7 +206,7 @@ def _local_exec_script( if env_params: env = {key: str(val) for (key, val) in env_params.items()} - if sys.platform == "win32": + if sys.platform == 'win32': # A hack to run Python on Windows with env variables set: env_copy = environ.copy() env_copy["PYTHONPATH"] = "" @@ -226,25 +214,16 @@ def _local_exec_script( env = env_copy try: - if sys.platform != "win32": + if sys.platform != 'win32': cmd = [" ".join(cmd)] _LOG.info("Run: %s", cmd) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Expands to: %s", Template(" ".join(cmd)).safe_substitute(env) - ) + _LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env)) _LOG.debug("Current working dir: %s", cwd) - proc = subprocess.run( - cmd, - env=env or None, - cwd=cwd, - shell=True, - text=True, - check=False, - capture_output=True, - ) + proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True, + text=True, check=False, capture_output=True) _LOG.debug("Run: return code = %d", proc.returncode) return (proc.returncode, proc.stdout, proc.stderr) diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py index cdfe510799..a0cf3e0e57 100644 --- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py +++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py @@ -28,13 +28,11 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta): This class is not supposed to be used as a standalone service. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of a service that provides temporary directory context for local exec service. @@ -52,24 +50,18 @@ def __init__( New methods to register with the service. 
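The `_local_exec_script` hunk above keeps the same `subprocess.run` flags on both sides of the reformat. Pulled out of the class for clarity, and omitting the Windows `PYTHONPATH` workaround, a single script line executes roughly like this:

    import subprocess
    from typing import Mapping, Optional, Tuple

    def run_line_sketch(cmd_line: str, env: Optional[Mapping[str, str]] = None,
                        cwd: Optional[str] = None) -> Tuple[int, str, str]:
        # One shell invocation per line; never raise on failure, just report
        # (return_code, stdout, stderr) back to the caller.
        proc = subprocess.run(cmd_line, env=dict(env) if env else None, cwd=cwd,
                              shell=True, text=True, check=False, capture_output=True)
        return (proc.returncode, proc.stdout, proc.stderr)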
""" super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.temp_dir_context]), + config, global_config, parent, + self.merge_methods(methods, [self.temp_dir_context]) ) self._temp_dir = self.config.get("temp_dir") if self._temp_dir: # expand globals - self._temp_dir = Template(self._temp_dir).safe_substitute( - global_config or {} - ) + self._temp_dir = Template(self._temp_dir).safe_substitute(global_config or {}) # and resolve the path to absolute path self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir) _LOG.info("%s: temp dir: %s", self, self._temp_dir) - def temp_dir_context( - self, path: Optional[str] = None - ) -> Union[TemporaryDirectory, nullcontext]: + def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py index 12fe62eeb7..61a6c74942 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py @@ -13,9 +13,9 @@ from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService __all__ = [ - "AzureAuthService", - "AzureFileShareService", - "AzureNetworkService", - "AzureSaaSConfigService", - "AzureVMService", + 'AzureAuthService', + 'AzureFileShareService', + 'AzureNetworkService', + 'AzureSaaSConfigService', + 'AzureVMService', ] diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index a5a6bc549c..4121446caf 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -27,15 +27,13 @@ class AzureAuthService(Service, SupportsAuth): Helper methods to get access to Azure services. """ - _REQ_INTERVAL = 300 # = 5 min - - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + _REQ_INTERVAL = 300 # = 5 min + + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure authentication services proxy. @@ -52,27 +50,18 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.get_access_token, - self.get_auth_headers, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + self.get_access_token, + self.get_auth_headers, + ]) ) # This parameter can come from command line as strings, so conversion is needed. - self._req_interval = float( - self.config.get("tokenRequestInterval", self._REQ_INTERVAL) - ) + self._req_interval = float(self.config.get("tokenRequestInterval", self._REQ_INTERVAL)) self._access_token = "RENEW *NOW*" - self._token_expiration_ts = datetime.now( - UTC - ) # Typically, some future timestamp. + self._token_expiration_ts = datetime.now(UTC) # Typically, some future timestamp. 
# Login as ourselves self._cred: Union[azure_id.AzureCliCredential, azure_id.CertificateCredential] @@ -81,13 +70,12 @@ def __init__( # Verify info required for SP auth early if "spClientId" in self.config: check_required_params( - self.config, - { + self.config, { "spClientId", "keyVaultName", "certName", "tenant", - }, + } ) def _init_sp(self) -> None: @@ -116,9 +104,7 @@ def _init_sp(self) -> None: cert_bytes = b64decode(secret.value) # Reauthenticate as the service principal. - self._cred = azure_id.CertificateCredential( - tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes - ) + self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes) def get_access_token(self) -> str: """ @@ -135,9 +121,7 @@ def get_access_token(self) -> str: res = self._cred.get_token("https://management.azure.com/.default") self._token_expiration_ts = datetime.fromtimestamp(res.expires_on, tz=UTC) self._access_token = res.token - _LOG.info( - "Got new accessToken. Expiration time: %s", self._token_expiration_ts - ) + _LOG.info("Got new accessToken. Expiration time: %s", self._token_expiration_ts) return self._access_token def get_auth_headers(self) -> dict: diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index a494867aa0..9f2b504aff 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -29,9 +29,9 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): Helper methods to manage and deploy Azure resources via REST APIs. """ - _POLL_INTERVAL = 4 # seconds - _POLL_TIMEOUT = 300 # seconds - _REQUEST_TIMEOUT = 5 # seconds + _POLL_INTERVAL = 4 # seconds + _POLL_TIMEOUT = 300 # seconds + _REQUEST_TIMEOUT = 5 # seconds _REQUEST_TOTAL_RETRIES = 10 # Total number retries for each request _REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries})) @@ -39,21 +39,19 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): # https://docs.microsoft.com/en-us/rest/api/resources/deployments _URL_DEPLOY = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Resources" - + "/deployments/{deployment_name}" - + "?api-version=2022-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Resources" + + "/deployments/{deployment_name}" + + "?api-version=2022-05-01" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of an Azure Services proxy. 
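`get_access_token` above caches the token and its expiration timestamp and refreshes them from the credential; the exact renew-ahead condition (driven by `tokenRequestInterval`) sits outside this hunk, so the margin check in the sketch below is an assumption. The class name and `fetch` callback are likewise placeholders, not project API.

    from datetime import datetime, timedelta, timezone
    from typing import Callable, Optional, Tuple

    class TokenCacheSketch:
        """Cache a bearer token and renew it shortly before it expires."""

        def __init__(self, fetch: Callable[[], Tuple[str, datetime]],
                     renew_margin_s: float = 300.0):
            self._fetch = fetch  # returns (token, UTC expiration timestamp)
            self._margin = timedelta(seconds=renew_margin_s)
            self._token: Optional[str] = None
            self._expires_at = datetime.min.replace(tzinfo=timezone.utc)

        def get(self) -> str:
            now = datetime.now(timezone.utc)
            if self._token is None or now + self._margin >= self._expires_at:
                (self._token, self._expires_at) = self._fetch()
            return self._token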
@@ -71,50 +69,32 @@ def __init__( """ super().__init__(config, global_config, parent, methods) - check_required_params( - self.config, - [ - "subscription", - "resourceGroup", - ], - ) + check_required_params(self.config, [ + "subscription", + "resourceGroup", + ]) # These parameters can come from command line as strings, so conversion is needed. - self._poll_interval = float( - self.config.get("pollInterval", self._POLL_INTERVAL) - ) + self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL)) self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT)) - self._request_timeout = float( - self.config.get("requestTimeout", self._REQUEST_TIMEOUT) - ) - self._total_retries = int( - self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES) - ) - self._backoff_factor = float( - self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR) - ) + self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) + self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)) + self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)) self._deploy_template = {} self._deploy_params = {} if self.config.get("deploymentTemplatePath") is not None: # TODO: Provide external schema validation? template = self.config_loader_service.load_config( - self.config["deploymentTemplatePath"], schema_type=None - ) + self.config['deploymentTemplatePath'], schema_type=None) assert template is not None and isinstance(template, dict) self._deploy_template = template # Allow for recursive variable expansion as we do with global params and const_args. - deploy_params = DictTemplater( - self.config["deploymentTemplateParameters"] - ).expand_vars(extra_source_dict=global_config) - self._deploy_params = merge_parameters( - dest=deploy_params, source=global_config - ) + deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config) + self._deploy_params = merge_parameters(dest=deploy_params, source=global_config) else: - _LOG.info( - "No deploymentTemplatePath provided. Deployment services will be unavailable." - ) + _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.") @property def deploy_params(self) -> dict: @@ -149,10 +129,7 @@ def _get_session(self, params: dict) -> requests.Session: session = requests.Session() session.mount( "https://", - HTTPAdapter( - max_retries=Retry(total=total_retries, backoff_factor=backoff_factor) - ), - ) + HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor))) session.headers.update(self._get_headers()) return session @@ -160,9 +137,8 @@ def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance( - self._parent, SupportsAuth - ), "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ + "Authorization service not provided. Include service-auth.jsonc?" 
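For reference, `_get_session` above builds an ordinary `requests.Session` with urllib3 retry semantics mounted on `https://`. Extracted into a free function, with default values copied from the class constants in this hunk:

    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry

    def retry_session_sketch(total_retries: int = 10, backoff_factor: float = 0.3) -> requests.Session:
        # Per the comment on _REQUEST_RETRY_BACKOFF_FACTOR, the delay between
        # attempts grows roughly as backoff_factor * (2 ** previous_retries).
        session = requests.Session()
        retry = Retry(total=total_retries, backoff_factor=backoff_factor)
        session.mount("https://", HTTPAdapter(max_retries=retry))
        return session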
return self._parent.get_auth_headers() @staticmethod @@ -177,15 +153,11 @@ def _extract_arm_parameters(json_data: dict) -> dict: """ return { key: val.get("value") - for (key, val) in json_data.get("properties", {}) - .get("parameters", {}) - .items() + for (key, val) in json_data.get("properties", {}).get("parameters", {}).items() if val.get("value") is not None } - def _azure_rest_api_post_helper( - self, params: dict, url: str - ) -> Tuple[Status, dict]: + def _azure_rest_api_post_helper(self, params: dict, url: str) -> Tuple[Status, dict]: """ General pattern for performing an action on an Azure resource via its REST API. @@ -207,9 +179,7 @@ def _azure_rest_api_post_helper( """ _LOG.debug("Request: POST %s", url) - response = requests.post( - url, headers=self._get_headers(), timeout=self._request_timeout - ) + response = requests.post(url, headers=self._get_headers(), timeout=self._request_timeout) _LOG.debug("Response: %s", response) # Logical flow for async operations based on: @@ -257,20 +227,16 @@ def _check_operation_status(self, params: dict) -> Tuple[Status, dict]: try: response = session.get(url, timeout=self._request_timeout) except requests.exceptions.ReadTimeout: - _LOG.warning( - "Request timed out after %.2f s: %s", self._request_timeout, url - ) + _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url) return Status.RUNNING, {} except requests.exceptions.RequestException as ex: _LOG.exception("Error in request checking operation status", exc_info=ex) return (Status.FAILED, {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Response: %s\n%s", - response, - json.dumps(response.json(), indent=2) if response.content else "", - ) + _LOG.debug("Response: %s\n%s", response, + json.dumps(response.json(), indent=2) + if response.content else "") if response.status_code == 200: output = response.json() @@ -303,19 +269,12 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ params = self._set_default_params(params) - _LOG.info( - "Wait for %s to %s", - params.get("deploymentName"), - "provision" if is_setup else "deprovision", - ) + _LOG.info("Wait for %s to %s", params.get("deploymentName"), + "provision" if is_setup else "deprovision") return self._wait_while(self._check_deployment, Status.PENDING, params) - def _wait_while( - self, - func: Callable[[dict], Tuple[Status, dict]], - loop_status: Status, - params: dict, - ) -> Tuple[Status, dict]: + def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], + loop_status: Status, params: dict) -> Tuple[Status, dict]: """ Invoke `func` periodically while the status is equal to `loop_status`. Return TIMED_OUT when timing out. 
@@ -337,18 +296,12 @@ def _wait_while( """ params = self._set_default_params(params) config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"] - ) + dest=self.config.copy(), source=params, required_keys=["deploymentName"]) poll_period = params.get("pollInterval", self._poll_interval) - _LOG.debug( - "Wait for %s status %s :: poll %.2f timeout %d s", - config["deploymentName"], - loop_status, - poll_period, - self._poll_timeout, - ) + _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s", + config["deploymentName"], loop_status, poll_period, self._poll_timeout) ts_timeout = time.time() + self._poll_timeout poll_delay = poll_period @@ -372,9 +325,7 @@ def _wait_while( _LOG.warning("Request timed out: %s", params) return (Status.TIMED_OUT, {}) - def _check_deployment( - self, params: dict - ) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements + def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements """ Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. @@ -400,7 +351,7 @@ def _check_deployment( "subscription", "resourceGroup", "deploymentName", - ], + ] ) _LOG.info("Check deployment: %s", config["deploymentName"]) @@ -415,9 +366,7 @@ def _check_deployment( try: response = session.get(url, timeout=self._request_timeout) except requests.exceptions.ReadTimeout: - _LOG.warning( - "Request timed out after %.2f s: %s", self._request_timeout, url - ) + _LOG.warning("Request timed out after %.2f s: %s", self._request_timeout, url) return Status.RUNNING, {} except requests.exceptions.RequestException as ex: _LOG.exception("Error in request checking deployment", exc_info=ex) @@ -463,18 +412,13 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: if not self._deploy_template: raise ValueError(f"Missing deployment template: {self}") params = self._set_default_params(params) - config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"] - ) + config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"]) _LOG.info("Deploy: %s :: %s", config["deploymentName"], params) params = merge_parameters(dest=self._deploy_params.copy(), source=params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Deploy: %s merged params ::\n%s", - config["deploymentName"], - json.dumps(params, indent=2), - ) + _LOG.debug("Deploy: %s merged params ::\n%s", + config["deploymentName"], json.dumps(params, indent=2)) url = self._URL_DEPLOY.format( subscription=config["subscription"], @@ -487,29 +431,22 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: "mode": "Incremental", "template": self._deploy_template, "parameters": { - key: {"value": val} - for (key, val) in params.items() + key: {"value": val} for (key, val) in params.items() if key in self._deploy_template.get("parameters", {}) - }, + } } } if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2)) - response = requests.put( - url, - json=json_req, - headers=self._get_headers(), - timeout=self._request_timeout, - ) + response = requests.put(url, json=json_req, + headers=self._get_headers(), timeout=self._request_timeout) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Response: %s\n%s", - response, - json.dumps(response.json(), indent=2) if response.content else "", - ) + _LOG.debug("Response: %s\n%s", response, + json.dumps(response.json(), indent=2) 
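Stripped of the Service plumbing, the contract of `_wait_while` above is: keep calling the check function while it reports the loop status, and give up with TIMED_OUT once the time budget is spent. A reduced sketch, with plain strings standing in for the project's `Status` enum and a fixed sleep instead of the adaptive `poll_delay` bookkeeping:

    import time
    from typing import Callable, Dict, Tuple

    def wait_while_sketch(check: Callable[[], Tuple[str, Dict]], loop_status: str,
                          poll_interval: float = 4.0, poll_timeout: float = 300.0) -> Tuple[str, Dict]:
        ts_timeout = time.time() + poll_timeout
        while time.time() < ts_timeout:
            (status, output) = check()
            if status != loop_status:
                return (status, output)  # resolved to something other than loop_status
            time.sleep(poll_interval)
        return ("TIMED_OUT", {})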
+ if response.content else "") else: _LOG.info("Response: %s", response) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index bec45a967d..6ccd4ba09d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -27,13 +27,11 @@ class AzureFileShareService(FileShareService): _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}" - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new file share Service for Azure environments with a given config. @@ -51,19 +49,16 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.upload, self.download]), + config, global_config, parent, + self.merge_methods(methods, [self.upload, self.download]) ) check_required_params( - self.config, - { + self.config, { "storageAccountName", "storageFileShareName", "storageAccountKey", - }, + } ) self._share_client = ShareClient.from_share_url( @@ -74,9 +69,7 @@ def __init__( credential=self.config["storageAccountKey"], ) - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: super().download(params, remote_path, local_path, recursive) dir_client = self._share_client.get_directory_client(remote_path) if dir_client.exists(): @@ -101,15 +94,11 @@ def download( # Translate into non-Azure exception: raise FileNotFoundError(f"Cannot download: {remote_path}") from ex - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: super().upload(params, local_path, remote_path, recursive) self._upload(local_path, remote_path, recursive, set()) - def _upload( - self, local_path: str, remote_path: str, recursive: bool, seen: Set[str] - ) -> None: + def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None: """ Upload contents from a local path to an Azure file share. This method is called from `.upload()` above. 
We need it to avoid exposing diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index 95e16892cc..d65ee02cfd 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -32,22 +32,20 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 _URL_DEPROVISION = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Network" - + "/virtualNetwork/{vnet_name}" - + "/delete" - + "?api-version=2023-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Network" + + "/virtualNetwork/{vnet_name}" + + "/delete" + + "?api-version=2023-05-01" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure Network services proxy. @@ -64,40 +62,28 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - # SupportsNetworkProvisioning - self.provision_network, - self.deprovision_network, - self.wait_network_deployment, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + # SupportsNetworkProvisioning + self.provision_network, + self.deprovision_network, + self.wait_network_deployment, + ]) ) if not self._deploy_template: - raise ValueError( - "AzureNetworkService requires a deployment template:\n" - + f"config={config}\nglobal_config={global_config}" - ) + raise ValueError("AzureNetworkService requires a deployment template:\n" + + f"config={config}\nglobal_config={global_config}") - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vnetName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vnetName']}-deployment" - _LOG.info( - "deploymentName missing from params. Defaulting to '%s'.", - params["deploymentName"], - ) + _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) return params - def wait_network_deployment( - self, params: dict, *, is_setup: bool - ) -> Tuple[Status, dict]: + def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: """ Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. 
@@ -138,9 +124,7 @@ def provision_network(self, params: dict) -> Tuple[Status, dict]: """ return self._provision_resource(params) - def deprovision_network( - self, params: dict, ignore_errors: bool = True - ) -> Tuple[Status, dict]: + def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple[Status, dict]: """ Deprovisions the virtual network on Azure by deleting it. @@ -167,18 +151,15 @@ def deprovision_network( "resourceGroup", "deploymentName", "vnetName", - ], + ] ) _LOG.info("Deprovision Network: %s", config["vnetName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) - (status, results) = self._azure_rest_api_post_helper( - config, - self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vnet_name=config["vnetName"], - ), - ) + (status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vnet_name=config["vnetName"], + )) if ignore_errors and status == Status.FAILED: _LOG.warning("Ignoring error: %s", results) status = Status.SUCCEEDED diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index 03928a4b18..a92d279a6d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -32,22 +32,20 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig): # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations _URL_CONFIGURE = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/{provider}" - + "/{server_type}/{vm_name}" - + "/{update}" - + "?api-version={api_version}" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/{provider}" + + "/{server_type}/{vm_name}" + + "/{update}" + + "?api-version={api_version}" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure services proxy. @@ -64,20 +62,18 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.configure, self.is_config_pending]), + config, global_config, parent, + self.merge_methods(methods, [ + self.configure, + self.is_config_pending + ]) ) - check_required_params( - self.config, - { - "subscription", - "resourceGroup", - "provider", - }, - ) + check_required_params(self.config, { + "subscription", + "resourceGroup", + "provider", + }) # Provide sane defaults for known DB providers. 
provider = self.config.get("provider") @@ -104,11 +100,7 @@ def __init__( provider=self.config["provider"], vm_name="{vm_name}", server_type="flexibleServers" if is_flex else "servers", - update=( - "updateConfigurations" - if self._is_batch - else "configurations/{param_name}" - ), + update="updateConfigurations" if self._is_batch else "configurations/{param_name}", api_version=api_version, ) @@ -123,13 +115,10 @@ def __init__( ) # These parameters can come from command line as strings, so conversion is needed. - self._request_timeout = float( - self.config.get("requestTimeout", self._REQUEST_TIMEOUT) - ) + self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - def configure( - self, config: Dict[str, Any], params: Dict[str, Any] - ) -> Tuple[Status, dict]: + def configure(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service. @@ -168,43 +157,32 @@ def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]: Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED} """ config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"] - ) + dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_get.format(vm_name=config["vmName"]) _LOG.debug("Request: GET %s", url) response = requests.put( - url, headers=self._get_headers(), timeout=self._request_timeout - ) + url, headers=self._get_headers(), timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) if response.status_code != 200: return (Status.FAILED, {}) # Currently, Azure Flex servers require a VM reboot. - return ( - Status.SUCCEEDED, - { - "isConfigPendingReboot": any( - {"False": False, "True": True}[ - val["properties"]["isConfigPendingRestart"] - ] - for val in response.json()["value"] - ) - }, - ) + return (Status.SUCCEEDED, {"isConfigPendingReboot": any( + {'False': False, 'True': True}[val['properties']['isConfigPendingRestart']] + for val in response.json()['value'] + )}) def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance( - self._parent, SupportsAuth - ), "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ + "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() - def _config_one( - self, config: Dict[str, Any], param_name: str, param_value: Any - ) -> Tuple[Status, dict]: + def _config_one(self, config: Dict[str, Any], + param_name: str, param_value: Any) -> Tuple[Status, dict]: """ Update a single parameter of the Azure DB service. 
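The `isConfigPendingReboot` result above folds the whole REST response down to one boolean; the API reports `isConfigPendingRestart` per setting as the strings "True"/"False". In isolation, that reduction is just:

    from typing import Any, Dict, Iterable

    def any_pending_restart_sketch(values: Iterable[Dict[str, Any]]) -> bool:
        # `values` is the "value" array of configuration entries returned by the service.
        return any(
            {"False": False, "True": True}[entry["properties"]["isConfigPendingRestart"]]
            for entry in values
        )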
@@ -224,18 +202,12 @@ def _config_one( Status is one of {PENDING, SUCCEEDED, FAILED} """ config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"] - ) - url = self._url_config_set.format( - vm_name=config["vmName"], param_name=param_name - ) + dest=self.config.copy(), source=config, required_keys=["vmName"]) + url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name) _LOG.debug("Request: PUT %s", url) - response = requests.put( - url, - headers=self._get_headers(), - json={"properties": {"value": str(param_value)}}, - timeout=self._request_timeout, - ) + response = requests.put(url, headers=self._get_headers(), + json={"properties": {"value": str(param_value)}}, + timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) @@ -243,9 +215,8 @@ def _config_one( return (Status.SUCCEEDED, {}) return (Status.FAILED, {}) - def _config_many( - self, config: Dict[str, Any], params: Dict[str, Any] - ) -> Tuple[Status, dict]: + def _config_many(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service one-by-one. (If batch API is not available for it). @@ -263,15 +234,14 @@ def _config_many( A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ - for param_name, param_value in params.items(): + for (param_name, param_value) in params.items(): (status, result) = self._config_one(config, param_name, param_value) if not status.is_succeeded(): return (status, result) return (Status.SUCCEEDED, {}) - def _config_batch( - self, config: Dict[str, Any], params: Dict[str, Any] - ) -> Tuple[Status, dict]: + def _config_batch(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple[Status, dict]: """ Batch update the parameters of an Azure DB service. @@ -289,8 +259,7 @@ def _config_batch( Status is one of {PENDING, SUCCEEDED, FAILED} """ config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"] - ) + dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_set.format(vm_name=config["vmName"]) json_req = { "value": [ @@ -300,12 +269,8 @@ def _config_batch( # "resetAllToDefault": "True" } _LOG.debug("Request: POST %s", url) - response = requests.post( - url, - headers=self._get_headers(), - json=json_req, - timeout=self._request_timeout, - ) + response = requests.post(url, headers=self._get_headers(), + json=json_req, timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index 5f79219c08..ddce3cc935 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -26,13 +26,7 @@ _LOG = logging.getLogger(__name__) -class AzureVMService( - AzureDeploymentService, - SupportsHostProvisioning, - SupportsHostOps, - SupportsOSOps, - SupportsRemoteExec, -): +class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec): """ Helper methods to manage VMs on Azure. 
""" @@ -44,35 +38,35 @@ class AzureVMService( # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start _URL_START = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/start" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/start" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off _URL_STOP = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/powerOff" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/powerOff" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate _URL_DEALLOCATE = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/deallocate" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/deallocate" + + "?api-version=2022-03-01" ) # TODO: This is probably the more correct URL to use for the deprovision operation. @@ -94,33 +88,31 @@ class AzureVMService( # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart _URL_REBOOT = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/restart" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/restart" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command _URL_REXEC_RUN = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/runCommand" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/runCommand" + + "?api-version=2022-03-01" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure VM services proxy. @@ -137,31 +129,26 @@ def __init__( New methods to register with the service. 
""" super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - # SupportsHostProvisioning - self.provision_host, - self.deprovision_host, - self.deallocate_host, - self.wait_host_deployment, - # SupportsHostOps - self.start_host, - self.stop_host, - self.restart_host, - self.wait_host_operation, - # SupportsOSOps - self.shutdown, - self.reboot, - self.wait_os_operation, - # SupportsRemoteExec - self.remote_exec, - self.get_remote_exec_results, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + # SupportsHostProvisioning + self.provision_host, + self.deprovision_host, + self.deallocate_host, + self.wait_host_deployment, + # SupportsHostOps + self.start_host, + self.stop_host, + self.restart_host, + self.wait_host_operation, + # SupportsOSOps + self.shutdown, + self.reboot, + self.wait_os_operation, + # SupportsRemoteExec + self.remote_exec, + self.get_remote_exec_results, + ]) ) # As a convenience, allow reading customData out of a file, rather than @@ -170,29 +157,22 @@ def __init__( # can be done using the `base64()` string function inside the ARM template. self._custom_data_file = self.config.get("customDataFile", None) if self._custom_data_file: - if self._deploy_params.get("customData", None): + if self._deploy_params.get('customData', None): raise ValueError("Both customDataFile and customData are specified.") - self._custom_data_file = self.config_loader_service.resolve_path( - self._custom_data_file - ) - with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh: + self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file) + with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh: self._deploy_params["customData"] = custom_data_fh.read() - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vmName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vmName']}-deployment" - _LOG.info( - "deploymentName missing from params. Defaulting to '%s'.", - params["deploymentName"], - ) + _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) return params - def wait_host_deployment( - self, params: dict, *, is_setup: bool - ) -> Tuple[Status, dict]: + def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: """ Waits for a pending operation on an Azure VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. @@ -284,19 +264,16 @@ def deprovision_host(self, params: dict) -> Tuple[Status, dict]: "resourceGroup", "deploymentName", "vmName", - ], + ] ) _LOG.info("Deprovision VM: %s", config["vmName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) # TODO: Properly deprovision *all* resources specified in the ARM template. 
- return self._azure_rest_api_post_helper( - config, - self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def deallocate_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -324,17 +301,14 @@ def deallocate_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Deallocate VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper( - config, - self._URL_DEALLOCATE.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def start_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -359,17 +333,14 @@ def start_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Start VM: %s :: %s", config["vmName"], params) - return self._azure_rest_api_post_helper( - config, - self._URL_START.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_START.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: """ @@ -396,17 +367,14 @@ def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Stop VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper( - config, - self._URL_STOP.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_STOP.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.stop_host(params, force) @@ -436,24 +404,20 @@ def restart_host(self, params: dict, force: bool = False) -> Tuple[Status, dict] "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Reboot VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper( - config, - self._URL_REBOOT.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.restart_host(params, force) - def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict - ) -> Tuple[Status, dict]: + def remote_exec(self, script: Iterable[str], config: dict, + env_params: dict) -> Tuple[Status, dict]: """ Run a command on Azure VM. 
@@ -483,20 +447,16 @@ def remote_exec( "subscription", "resourceGroup", "vmName", - ], + ] ) if _LOG.isEnabledFor(logging.INFO): - _LOG.info( - "Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script) - ) + _LOG.info("Run a script on VM: %s\n %s", config["vmName"], "\n ".join(script)) json_req = { "commandId": "RunShellScript", "script": list(script), - "parameters": [ - {"name": key, "value": val} for (key, val) in env_params.items() - ], + "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()] } url = self._URL_REXEC_RUN.format( @@ -509,18 +469,12 @@ def remote_exec( _LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2)) response = requests.post( - url, - json=json_req, - headers=self._get_headers(), - timeout=self._request_timeout, - ) + url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Response: %s\n%s", - response, - json.dumps(response.json(), indent=2) if response.content else "", - ) + _LOG.debug("Response: %s\n%s", response, + json.dumps(response.json(), indent=2) + if response.content else "") else: _LOG.info("Response: %s", response) @@ -528,13 +482,10 @@ def remote_exec( # TODO: extract the results from JSON response return (Status.SUCCEEDED, config) elif response.status_code == 202: - return ( - Status.PENDING, - { - **config, - "asyncResultsUrl": response.headers.get("Azure-AsyncOperation"), - }, - ) + return (Status.PENDING, { + **config, + "asyncResultsUrl": response.headers.get("Azure-AsyncOperation") + }) else: _LOG.error("Response: %s :: %s", response, response.text) # _LOG.error("Bad Request:\n%s", response.request.body) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index 19290886e4..f623cdfcc8 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -31,14 +31,9 @@ class CopyMode(Enum): class SshFileShareService(FileShareService, SshService): """A collection of functions for interacting with SSH servers as file shares.""" - async def _start_file_copy( - self, - params: dict, - mode: CopyMode, - local_path: str, - remote_path: str, - recursive: bool = True, - ) -> None: + async def _start_file_copy(self, params: dict, mode: CopyMode, + local_path: str, remote_path: str, + recursive: bool = True) -> None: # pylint: disable=too-many-arguments """ Starts a file copy operation @@ -76,74 +71,44 @@ async def _start_file_copy( dstpath = (connection, remote_path) else: raise ValueError(f"Unknown copy mode: {mode}") - return await scp( - srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True - ) + return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True) - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ], + ] ) super().download(params, remote_path, local_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy( - params, CopyMode.DOWNLOAD, local_path, remote_path, recursive - ) - ) + self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive)) try: file_copy_future.result() except (OSError, SFTPError) as ex: - _LOG.error( - 
"Failed to download %s to %s from %s: %s", - remote_path, - local_path, - params, - ex, - ) + _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex) if isinstance(ex, SFTPNoSuchFile) or ( - isinstance(ex, SFTPFailure) - and ex.code == 4 - and any( - msg.lower() in ex.reason.lower() - for msg in ("File not found", "No such file or directory") - ) + isinstance(ex, SFTPFailure) and ex.code == 4 + and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory")) ): _LOG.warning("File %s does not exist on %s", remote_path, params) - raise FileNotFoundError( - f"File {remote_path} does not exist on {params}" - ) from ex + raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex raise ex - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ], + ] ) super().upload(params, local_path, remote_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy( - params, CopyMode.UPLOAD, local_path, remote_path, recursive - ) - ) + self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)) try: file_copy_future.result() except (OSError, SFTPError) as ex: - _LOG.error( - "Failed to upload %s to %s on %s: %s", - local_path, - remote_path, - params, - ex, - ) + _LOG.error("Failed to upload %s to %s on %s: %s", local_path, remote_path, params, ex) raise ex diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index dad7cb971c..a650ff0707 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -29,13 +29,11 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of an SSH Service. @@ -54,25 +52,17 @@ def __init__( # Same methods are also provided by the AzureVMService class # pylint: disable=duplicate-code super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.shutdown, - self.reboot, - self.wait_os_operation, - self.remote_exec, - self.get_remote_exec_results, - ], - ), - ) + config, global_config, parent, + self.merge_methods(methods, [ + self.shutdown, + self.reboot, + self.wait_os_operation, + self.remote_exec, + self.get_remote_exec_results, + ])) self._shell = self.config.get("ssh_shell", "/bin/bash") - async def _run_cmd( - self, params: dict, script: Iterable[str], env_params: dict - ) -> SSHCompletedProcess: + async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess: """ Runs a command asynchronously on a host via SSH. 
@@ -94,22 +84,17 @@ async def _run_cmd( connection, _ = await self._get_client_connection(params) # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values. # Handle transferring environment variables by making a script to set them. - env_script_lines = [ - f"export {name}='{value}'" for (name, value) in env_params.items() - ] - script_lines = env_script_lines + [ - line_split for line in script for line_split in line.splitlines() - ] + env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()] + script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()] # Note: connection.run() uses "exec" with a shell by default. - script_str = "\n".join(script_lines) + script_str = '\n'.join(script_lines) _LOG.debug("Running script on %s:\n%s", connection, script_str) - return await connection.run( - script_str, check=False, timeout=self._request_timeout, env=env_params - ) + return await connection.run(script_str, + check=False, + timeout=self._request_timeout, + env=env_params) - def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict - ) -> Tuple["Status", dict]: + def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]: """ Start running a command on remote host OS. @@ -136,11 +121,9 @@ def remote_exec( source=config, required_keys=[ "ssh_hostname", - ], - ) - config["asyncRemoteExecResultsFuture"] = self._run_coroutine( - self._run_cmd(config, script, env_params) + ] ) + config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params)) return (Status.PENDING, config) def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: @@ -168,22 +151,10 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: try: result = future.result(timeout=self._request_timeout) assert isinstance(result, SSHCompletedProcess) - stdout = ( - result.stdout.decode() - if isinstance(result.stdout, bytes) - else result.stdout - ) - stderr = ( - result.stderr.decode() - if isinstance(result.stderr, bytes) - else result.stderr - ) + stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout + stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr return ( - ( - Status.SUCCEEDED - if result.exit_status == 0 and result.returncode == 0 - else Status.FAILED - ), + Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED, { "stdout": stdout, "stderr": stderr, @@ -194,9 +165,7 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: _LOG.error("Failed to get remote exec results: %s", ex) return (Status.FAILED, {"result": result}) - def _exec_os_op( - self, cmd_opts_list: List[str], params: dict - ) -> Tuple[Status, dict]: + def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]: """_summary_ Parameters @@ -217,9 +186,9 @@ def _exec_os_op( source=params, required_keys=[ "ssh_hostname", - ], + ] ) - cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list]) + cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list]) script = rf""" if [[ $EUID -ne 0 ]]; then sudo=$(command -v sudo) @@ -254,10 +223,10 @@ def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - "shutdown -h now", - "poweroff", - "halt -p", - "systemctl poweroff", + 'shutdown -h now', + 
'poweroff', + 'halt -p', + 'systemctl poweroff', ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) @@ -279,11 +248,11 @@ def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - "shutdown -r now", - "reboot", - "halt --reboot", - "systemctl reboot", - "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1", + 'shutdown -r now', + 'reboot', + 'halt --reboot', + 'systemctl reboot', + 'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1', ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index ae18ad4834..8bc90eb3da 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -50,8 +50,8 @@ class SshClient(asyncssh.SSHClient): reconnect for each command. """ - _CONNECTION_PENDING = "INIT" - _CONNECTION_LOST = "LOST" + _CONNECTION_PENDING = 'INIT' + _CONNECTION_LOST = 'LOST' def __init__(self, *args: tuple, **kwargs: dict): self._connection_id: str = SshClient._CONNECTION_PENDING @@ -65,7 +65,7 @@ def __repr__(self) -> str: @staticmethod def id_from_connection(connection: SSHClientConnection) -> str: """Gets a unique id repr for the connection.""" - return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access + return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access @staticmethod def id_from_params(connect_params: dict) -> str: @@ -79,12 +79,8 @@ def connection_made(self, conn: SSHClientConnection) -> None: Changes the connection_id from _CONNECTION_PENDING to a unique id repr. 
""" self._conn_event.clear() - _LOG.debug( - "%s: Connection made by %s: %s", - current_thread().name, - conn._options.env, - conn, - ) # pylint: disable=protected-access + _LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \ + # pylint: disable=protected-access self._connection_id = SshClient.id_from_connection(conn) self._connection = conn self._conn_event.set() @@ -94,19 +90,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self._conn_event.clear() _LOG.debug("%s: %s", current_thread().name, "connection_lost") if exc is None: - _LOG.debug( - "%s: gracefully disconnected ssh from %s: %s", - current_thread().name, - self._connection_id, - exc, - ) + _LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc) else: - _LOG.debug( - "%s: ssh connection lost on %s: %s", - current_thread().name, - self._connection_id, - exc, - ) + _LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc) self._connection_id = SshClient._CONNECTION_LOST self._connection = None self._conn_event.set() @@ -118,11 +104,7 @@ async def connection(self) -> Optional[SSHClientConnection]: """ _LOG.debug("%s: Waiting for connection to be available.", current_thread().name) await self._conn_event.wait() - _LOG.debug( - "%s: Connection available for %s", - current_thread().name, - self._connection_id, - ) + _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id) return self._connection @@ -163,9 +145,7 @@ def exit(self) -> None: warn(RuntimeWarning("SshClientCache lock was still held on exit.")) self._cache_lock.release() - async def get_client_connection( - self, connect_params: dict - ) -> Tuple[SSHClientConnection, SshClient]: + async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) client connection. @@ -179,57 +159,33 @@ async def get_client_connection( Tuple[SSHClientConnection, SshClient] A tuple of (SSHClientConnection, SshClient). """ - _LOG.debug( - "%s: get_client_connection: %s", current_thread().name, connect_params - ) + _LOG.debug("%s: get_client_connection: %s", current_thread().name, connect_params) async with self._cache_lock: connection_id = SshClient.id_from_params(connect_params) client: Union[None, SshClient, asyncssh.SSHClient] _, client = self._cache.get(connection_id, (None, None)) if client: - _LOG.debug( - "%s: Checking cached client %s", - current_thread().name, - connection_id, - ) + _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id) connection = await client.connection() if not connection: - _LOG.debug( - "%s: Removing stale client connection %s from cache.", - current_thread().name, - connection_id, - ) + _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id) self._cache.pop(connection_id) # Try to reconnect next. 
else: - _LOG.debug( - "%s: Using cached client %s", - current_thread().name, - connection_id, - ) + _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id) if connection_id not in self._cache: - _LOG.debug( - "%s: Establishing client connection to %s", - current_thread().name, - connection_id, - ) - connection, client = await asyncssh.create_connection( - SshClient, **connect_params - ) + _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id) + connection, client = await asyncssh.create_connection(SshClient, **connect_params) assert isinstance(client, SshClient) self._cache[connection_id] = (connection, client) - _LOG.debug( - "%s: Created connection to %s.", - current_thread().name, - connection_id, - ) + _LOG.debug("%s: Created connection to %s.", current_thread().name, connection_id) return self._cache[connection_id] def cleanup(self) -> None: """ Closes all cached connections. """ - for connection, _ in self._cache.values(): + for (connection, _) in self._cache.values(): connection.close() self._cache = {} @@ -269,28 +225,24 @@ class SshService(Service, metaclass=ABCMeta): _REQUEST_TIMEOUT: Optional[float] = None # seconds - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__(config, global_config, parent, methods) # Make sure that the value we allow overriding on a per-connection # basis are present in the config so merge_parameters can do its thing. - self.config.setdefault("ssh_port", None) - assert isinstance(self.config["ssh_port"], (int, type(None))) - self.config.setdefault("ssh_username", None) - assert isinstance(self.config["ssh_username"], (str, type(None))) - self.config.setdefault("ssh_priv_key_path", None) - assert isinstance(self.config["ssh_priv_key_path"], (str, type(None))) + self.config.setdefault('ssh_port', None) + assert isinstance(self.config['ssh_port'], (int, type(None))) + self.config.setdefault('ssh_username', None) + assert isinstance(self.config['ssh_username'], (str, type(None))) + self.config.setdefault('ssh_priv_key_path', None) + assert isinstance(self.config['ssh_priv_key_path'], (str, type(None))) # None can be used to disable the request timeout. - self._request_timeout = self.config.get( - "ssh_request_timeout", self._REQUEST_TIMEOUT - ) + self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT) self._request_timeout = nullable(float, self._request_timeout) # Prep an initial connect_params. @@ -298,32 +250,24 @@ def __init__( # In general scripted commands shouldn't need a pty and having one # available can confuse some commands, though we may need to make # this configurable in the future. - "request_pty": False, + 'request_pty': False, # By default disable known_hosts checking (since most VMs expected to be dynamically created). 
- "known_hosts": None, + 'known_hosts': None, } - if "ssh_known_hosts_file" in self.config: - self._connect_params["known_hosts"] = self.config.get( - "ssh_known_hosts_file", None - ) - if isinstance(self._connect_params["known_hosts"], str): - known_hosts_file = os.path.expanduser( - self._connect_params["known_hosts"] - ) + if 'ssh_known_hosts_file' in self.config: + self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None) + if isinstance(self._connect_params['known_hosts'], str): + known_hosts_file = os.path.expanduser(self._connect_params['known_hosts']) if not os.path.exists(known_hosts_file): - raise ValueError( - f"ssh_known_hosts_file {known_hosts_file} does not exist" - ) - self._connect_params["known_hosts"] = known_hosts_file - if self._connect_params["known_hosts"] is None: + raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist") + self._connect_params['known_hosts'] = known_hosts_file + if self._connect_params['known_hosts'] is None: _LOG.info("%s known_hosts checking is disabled per config.", self) - if "ssh_keepalive_interval" in self.config: - keepalive_internal = self.config.get("ssh_keepalive_interval") - self._connect_params["keepalive_interval"] = nullable( - int, keepalive_internal - ) + if 'ssh_keepalive_interval' in self.config: + keepalive_internal = self.config.get('ssh_keepalive_interval') + self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal) def _enter_context(self) -> "SshService": # Start the background thread if it's not already running. @@ -333,12 +277,9 @@ def _enter_context(self) -> "SshService": super()._enter_context() return self - def _exit_context( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def _exit_context(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: # Stop the background thread if it's not needed anymore and potentially # cleanup the cache as well. assert self._in_context @@ -354,9 +295,7 @@ def clear_client_cache(cls) -> None: """ cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup() - def _run_coroutine( - self, coro: Coroutine[Any, Any, CoroReturnType] - ) -> FutureReturnType: + def _run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: """ Runs the given coroutine in the background event loop thread. @@ -395,32 +334,28 @@ def _get_connect_params(self, params: dict) -> dict: # Start with the base config params. 
connect_params = self._connect_params.copy() - connect_params["host"] = params["ssh_hostname"] # required + connect_params['host'] = params['ssh_hostname'] # required - if params.get("ssh_port"): - connect_params["port"] = int(params.pop("ssh_port")) - elif self.config["ssh_port"]: - connect_params["port"] = int(self.config["ssh_port"]) + if params.get('ssh_port'): + connect_params['port'] = int(params.pop('ssh_port')) + elif self.config['ssh_port']: + connect_params['port'] = int(self.config['ssh_port']) - if "ssh_username" in params: - connect_params["username"] = str(params.pop("ssh_username")) - elif self.config["ssh_username"]: - connect_params["username"] = str(self.config["ssh_username"]) + if 'ssh_username' in params: + connect_params['username'] = str(params.pop('ssh_username')) + elif self.config['ssh_username']: + connect_params['username'] = str(self.config['ssh_username']) - priv_key_file: Optional[str] = params.get( - "ssh_priv_key_path", self.config["ssh_priv_key_path"] - ) + priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path']) if priv_key_file: priv_key_file = os.path.expanduser(priv_key_file) if not os.path.exists(priv_key_file): raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist") - connect_params["client_keys"] = [priv_key_file] + connect_params['client_keys'] = [priv_key_file] return connect_params - async def _get_client_connection( - self, params: dict - ) -> Tuple[SSHClientConnection, SshClient]: + async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) SshClient (connection) for the given connection params. @@ -435,8 +370,4 @@ async def _get_client_connection( The connection and client objects. """ assert self._in_context - return ( - await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection( - self._get_connect_params(params) - ) - ) + return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params)) diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index 02bb06e755..725d0c3306 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -18,12 +18,12 @@ from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec __all__ = [ - "SupportsAuth", - "SupportsConfigLoading", - "SupportsFileShareOps", - "SupportsHostProvisioning", - "SupportsLocalExec", - "SupportsNetworkProvisioning", - "SupportsRemoteConfig", - "SupportsRemoteExec", + 'SupportsAuth', + 'SupportsConfigLoading', + 'SupportsFileShareOps', + 'SupportsHostProvisioning', + 'SupportsLocalExec', + 'SupportsNetworkProvisioning', + 'SupportsRemoteConfig', + 'SupportsRemoteExec', ] diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index 04d1c44ca9..05853da0a9 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -34,9 +34,8 @@ class SupportsConfigLoading(Protocol): Protocol interface for helper functions to lookup and load configs. """ - def resolve_path( - self, file_path: str, extra_paths: Optional[Iterable[str]] = None - ) -> str: + def resolve_path(self, file_path: str, + extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. 
If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -54,9 +53,7 @@ def resolve_path( An actual path to the config or script. """ - def load_config( - self, json_file_name: str, schema_type: Optional[ConfigSchema] - ) -> Union[dict, List[dict]]: + def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. @@ -75,14 +72,12 @@ def load_config( Free-format dictionary that contains the configuration. """ - def build_environment( - self, # pylint: disable=too-many-arguments - config: dict, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None, - ) -> "Environment": + def build_environment(self, # pylint: disable=too-many-arguments + config: dict, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None) -> "Environment": """ Factory method for a new environment with a given config. @@ -112,13 +107,12 @@ def build_environment( """ def load_environment_list( # pylint: disable=too-many-arguments - self, - json_file_name: str, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None, - ) -> List["Environment"]: + self, + json_file_name: str, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None) -> List["Environment"]: """ Load and build a list of environments from the config file. @@ -143,12 +137,9 @@ def load_environment_list( # pylint: disable=too-many-arguments A list of new benchmarking environments. """ - def load_services( - self, - json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - ) -> "Service": + def load_services(self, json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None) -> "Service": """ Read the configuration files and bundle all service methods from those configs into a single Service object. diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py index 8252dc17ed..87ec9e49da 100644 --- a/mlos_bench/mlos_bench/services/types/fileshare_type.py +++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py @@ -15,9 +15,7 @@ class SupportsFileShareOps(Protocol): Protocol interface for file share operations. """ - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: """ Downloads contents from a remote share path to a local path. @@ -35,9 +33,7 @@ def download( if True (the default), download the entire directory tree. """ - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: """ Uploads contents from a local path to remote share path. 
diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index 31f1eb8097..77b481e48e 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -36,9 +36,7 @@ def provision_host(self, params: dict) -> Tuple["Status", dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ - def wait_host_deployment( - self, params: dict, *, is_setup: bool - ) -> Tuple["Status", dict]: + def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Status", dict]: """ Waits for a pending operation on a Host/VM to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index 126966c713..c4c5f01ddc 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -32,12 +32,9 @@ class SupportsLocalExec(Protocol): Used in LocalEnv and provided by LocalExecService. """ - def local_exec( - self, - script_lines: Iterable[str], - env: Optional[Mapping[str, TunableValue]] = None, - cwd: Optional[str] = None, - ) -> Tuple[int, str, str]: + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, TunableValue]] = None, + cwd: Optional[str] = None) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. @@ -58,9 +55,7 @@ def local_exec( A 3-tuple of return code, stdout, and stderr of the script process. """ - def temp_dir_context( - self, path: Optional[str] = None - ) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: + def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 5ce5ebb8e4..fb753aa21c 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -36,9 +36,7 @@ def provision_network(self, params: dict) -> Tuple["Status", dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ - def wait_network_deployment( - self, params: dict, *, is_setup: bool - ) -> Tuple["Status", dict]: + def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Status", dict]: """ Waits for a pending operation on a Network to resolve to SUCCEEDED or FAILED. Return TIMED_OUT when timing out. @@ -58,9 +56,7 @@ def wait_network_deployment( Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ - def deprovision_network( - self, params: dict, ignore_errors: bool = True - ) -> Tuple["Status", dict]: + def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]: """ Deprovisions the Network by deleting it. diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index 8a414fad8e..c653e10c2b 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -18,9 +18,8 @@ class SupportsRemoteConfig(Protocol): Protocol interface for configuring cloud services. 
""" - def configure( - self, config: Dict[str, Any], params: Dict[str, Any] - ) -> Tuple["Status", dict]: + def configure(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple["Status", dict]: """ Update the parameters of a SaaS service in the cloud. diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index f6ca57912a..096cb3c675 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -20,9 +20,8 @@ class SupportsRemoteExec(Protocol): scripts on a remote host OS. """ - def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict - ) -> Tuple["Status", dict]: + def remote_exec(self, script: Iterable[str], config: dict, + env_params: dict) -> Tuple["Status", dict]: """ Run a command on remote host OS. diff --git a/mlos_bench/mlos_bench/storage/__init__.py b/mlos_bench/mlos_bench/storage/__init__.py index 0812270747..9ae5c80f36 100644 --- a/mlos_bench/mlos_bench/storage/__init__.py +++ b/mlos_bench/mlos_bench/storage/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.storage.storage_factory import from_config __all__ = [ - "Storage", - "from_config", + 'Storage', + 'from_config', ] diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index 47581f0725..ce07e44e2b 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -32,15 +32,12 @@ class ExperimentData(metaclass=ABCMeta): RESULT_COLUMN_PREFIX = "result." CONFIG_COLUMN_PREFIX = "config." - def __init__( - self, - *, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str, - ): + def __init__(self, *, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str): self._experiment_id = experiment_id self._description = description self._root_env_config = root_env_config @@ -145,9 +142,9 @@ def default_tunable_config_id(self) -> Optional[int]: trials_items = sorted(self.trials.items()) if not trials_items: return None - for _trial_id, trial in trials_items: + for (_trial_id, trial) in trials_items: # Take the first config id marked as "defaults" when it was instantiated. - if strtobool(str(trial.metadata_dict.get("is_defaults", False))): + if strtobool(str(trial.metadata_dict.get('is_defaults', False))): return trial.tunable_config_id # Fallback (min trial_id) return trials_items[0][1].tunable_config_id diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 8167504627..2165fa706f 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -30,12 +30,10 @@ class Storage(metaclass=ABCMeta): and storage systems (e.g., SQLite or MLFLow). """ - def __init__( - self, - config: Dict[str, Any], - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + config: Dict[str, Any], + global_config: Optional[dict] = None, + service: Optional[Service] = None): """ Create a new storage object. 
@@ -76,16 +74,13 @@ def experiments(self) -> Dict[str, ExperimentData]: """ @abstractmethod - def experiment( - self, - *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal["min", "max"]], - ) -> "Storage.Experiment": + def experiment(self, *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment': """ Create a new experiment in the storage. @@ -121,27 +116,23 @@ class Experiment(metaclass=ABCMeta): This class is instantiated in the `Storage.experiment()` method. """ - def __init__( - self, - *, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal["min", "max"]], - ): + def __init__(self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal['min', 'max']]): self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( - root_env_config - ) + (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config) self._description = description self._opt_targets = opt_targets self._in_context = False - def __enter__(self) -> "Storage.Experiment": + def __enter__(self) -> 'Storage.Experiment': """ Enter the context of the experiment. @@ -153,12 +144,9 @@ def __enter__(self) -> "Storage.Experiment": self._in_context = True return self - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Literal[False]: """ End the context of the experiment. @@ -169,11 +157,8 @@ def __exit__( _LOG.debug("Finishing experiment: %s", self) else: assert exc_type and exc_val - _LOG.warning( - "Finishing experiment: %s", - self, - exc_info=(exc_type, exc_val, exc_tb), - ) + _LOG.warning("Finishing experiment: %s", self, + exc_info=(exc_type, exc_val, exc_tb)) assert self._in_context self._teardown(is_ok) self._in_context = False @@ -263,10 +248,8 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: """ @abstractmethod - def load( - self, - last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load(self, last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: """ Load (tunable values, benchmark scores, status) to warm-up the optimizer. @@ -286,9 +269,7 @@ def load( """ @abstractmethod - def pending_trials( - self, timestamp: datetime, *, running: bool - ) -> Iterator["Storage.Trial"]: + def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']: """ Return an iterator over the pending trials that are scheduled to run on or before the specified timestamp. 
@@ -308,12 +289,8 @@ def pending_trials( """ @abstractmethod - def new_trial( - self, - tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None, - ) -> "Storage.Trial": + def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial': """ Create a new experiment run in the storage. @@ -340,16 +317,10 @@ class Trial(metaclass=ABCMeta): This class is instantiated in the `Storage.Experiment.trial()` method. """ - def __init__( - self, - *, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - tunable_config_id: int, - opt_targets: Dict[str, Literal["min", "max"]], - config: Optional[Dict[str, Any]] = None, - ): + def __init__(self, *, + tunables: TunableGroups, experiment_id: str, trial_id: int, + tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']], + config: Optional[Dict[str, Any]] = None): self._tunables = tunables self._experiment_id = experiment_id self._trial_id = trial_id @@ -390,9 +361,7 @@ def tunables(self) -> TunableGroups: """ return self._tunables - def config( - self, global_config: Optional[Dict[str, Any]] = None - ) -> Dict[str, Any]: + def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ Produce a copy of the global configuration updated with the parameters of the current trial. @@ -409,12 +378,9 @@ def config( return config @abstractmethod - def update( - self, - status: Status, - timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None, - ) -> Optional[Dict[str, Any]]: + def update(self, status: Status, timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: """ Update the storage with the results of the experiment. @@ -438,21 +404,14 @@ def update( assert metrics is not None opt_targets = set(self._opt_targets.keys()) if not opt_targets.issubset(metrics.keys()): - _LOG.warning( - "Trial %s :: opt.targets missing: %s", - self, - opt_targets.difference(metrics.keys()), - ) + _LOG.warning("Trial %s :: opt.targets missing: %s", + self, opt_targets.difference(metrics.keys())) # raise ValueError() return metrics @abstractmethod - def update_telemetry( - self, - status: Status, - timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]], - ) -> None: + def update_telemetry(self, status: Status, timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]]) -> None: """ Save the experiment's telemetry data and intermediate status. @@ -465,6 +424,4 @@ def update_telemetry( metrics : List[Tuple[datetime, str, Any]] Telemetry data. """ - _LOG.info( - "Store telemetry: %s :: %s %d records", self, status, len(metrics) - ) + _LOG.info("Store telemetry: %s :: %s %d records", self, status, len(metrics)) diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index cc4eebf9df..b3b2bed86a 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -31,23 +31,18 @@ class TrialData(metaclass=ABCMeta): of tunable parameters). 
""" - def __init__( - self, - *, - experiment_id: str, - trial_id: int, - tunable_config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status, - ): + def __init__(self, *, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status): self._experiment_id = experiment_id self._trial_id = trial_id self._tunable_config_id = tunable_config_id assert ts_start.tzinfo == UTC, "ts_start must be in UTC" - assert ( - ts_end is None or ts_end.tzinfo == UTC - ), "ts_end must be in UTC if not None" + assert ts_end is None or ts_end.tzinfo == UTC, "ts_end must be in UTC if not None" self._ts_start = ts_start self._ts_end = ts_end self._status = status @@ -58,10 +53,7 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return ( - self._experiment_id == other._experiment_id - and self._trial_id == other._trial_id - ) + return self._experiment_id == other._experiment_id and self._trial_id == other._trial_id @property def experiment_id(self) -> str: diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py index 0c9adce22d..0dce110b1b 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py @@ -21,7 +21,8 @@ class TunableConfigData(metaclass=ABCMeta): A configuration in this context is the set of tunable parameter values. """ - def __init__(self, *, tunable_config_id: int): + def __init__(self, *, + tunable_config_id: int): self._tunable_config_id = tunable_config_id def __repr__(self) -> str: diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py index 6cabaaf3ba..18c50035a9 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py @@ -27,19 +27,14 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta): (e.g., for repeats), which we call a (tunable) config trial group. """ - def __init__( - self, - *, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None, - ): + def __init__(self, *, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None): self._experiment_id = experiment_id self._tunable_config_id = tunable_config_id # can be lazily initialized as necessary: - self._tunable_config_trial_group_id: Optional[int] = ( - tunable_config_trial_group_id - ) + self._tunable_config_trial_group_id: Optional[int] = tunable_config_trial_group_id @property def experiment_id(self) -> str: @@ -72,9 +67,7 @@ def tunable_config_trial_group_id(self) -> int: config_id. 
""" if self._tunable_config_trial_group_id is None: - self._tunable_config_trial_group_id = ( - self._get_tunable_config_trial_group_id() - ) + self._tunable_config_trial_group_id = self._get_tunable_config_trial_group_id() assert self._tunable_config_trial_group_id is not None return self._tunable_config_trial_group_id @@ -84,10 +77,7 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return ( - self._tunable_config_id == other._tunable_config_id - and self._experiment_id == other._experiment_id - ) + return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id @property @abstractmethod diff --git a/mlos_bench/mlos_bench/storage/sql/__init__.py b/mlos_bench/mlos_bench/storage/sql/__init__.py index cf09b9aa5a..735e21bcaf 100644 --- a/mlos_bench/mlos_bench/storage/sql/__init__.py +++ b/mlos_bench/mlos_bench/storage/sql/__init__.py @@ -8,5 +8,5 @@ from mlos_bench.storage.sql.storage import SqlStorage __all__ = [ - "SqlStorage", + 'SqlStorage', ] diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index 50d944150b..c7ee73a3bc 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -18,11 +18,10 @@ def get_trials( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None, -) -> Dict[int, TrialData]: + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. @@ -31,18 +30,13 @@ def get_trials( from mlos_bench.storage.sql.trial_data import ( TrialSqlData, # pylint: disable=import-outside-toplevel,cyclic-import ) - with engine.connect() as conn: # Build up sql a statement for fetching trials. - stmt = ( - schema.trial.select() - .where( - schema.trial.c.exp_id == experiment_id, - ) - .order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), - ) + stmt = schema.trial.select().where( + schema.trial.c.exp_id == experiment_id, + ).order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -66,11 +60,10 @@ def get_trials( def get_results_df( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None, -) -> pandas.DataFrame: + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: Optional[int] = None) -> pandas.DataFrame: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. @@ -79,22 +72,15 @@ def get_results_df( # pylint: disable=too-many-locals with engine.connect() as conn: # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config. 
- tunable_config_group_id_stmt = ( - schema.trial.select() - .with_only_columns( - schema.trial.c.exp_id, - schema.trial.c.config_id, - func.min(schema.trial.c.trial_id) - .cast(Integer) - .label("tunable_config_trial_group_id"), - ) - .where( - schema.trial.c.exp_id == experiment_id, - ) - .group_by( - schema.trial.c.exp_id, - schema.trial.c.config_id, - ) + tunable_config_group_id_stmt = schema.trial.select().with_only_columns( + schema.trial.c.exp_id, + schema.trial.c.config_id, + func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'), + ).where( + schema.trial.c.exp_id == experiment_id, + ).group_by( + schema.trial.c.exp_id, + schema.trial.c.config_id, ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -104,24 +90,18 @@ def get_results_df( tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery() # Get each trial's metadata. - cur_trials_stmt = ( - select( - schema.trial, - tunable_config_trial_group_id_subquery, - ) - .where( - schema.trial.c.exp_id == experiment_id, - and_( - tunable_config_trial_group_id_subquery.c.exp_id - == schema.trial.c.exp_id, - tunable_config_trial_group_id_subquery.c.config_id - == schema.trial.c.config_id, - ), - ) - .order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), - ) + cur_trials_stmt = select( + schema.trial, + tunable_config_trial_group_id_subquery, + ).where( + schema.trial.c.exp_id == experiment_id, + and_( + tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, + tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, + ), + ).order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -130,48 +110,39 @@ def get_results_df( ) cur_trials = conn.execute(cur_trials_stmt) trials_df = pandas.DataFrame( - [ - ( - row.trial_id, - utcify_timestamp(row.ts_start, origin="utc"), - utcify_nullable_timestamp(row.ts_end, origin="utc"), - row.config_id, - row.tunable_config_trial_group_id, - row.status, - ) - for row in cur_trials.fetchall() - ], + [( + row.trial_id, + utcify_timestamp(row.ts_start, origin="utc"), + utcify_nullable_timestamp(row.ts_end, origin="utc"), + row.config_id, + row.tunable_config_trial_group_id, + row.status, + ) for row in cur_trials.fetchall()], columns=[ - "trial_id", - "ts_start", - "ts_end", - "tunable_config_id", - "tunable_config_trial_group_id", - "status", - ], + 'trial_id', + 'ts_start', + 'ts_end', + 'tunable_config_id', + 'tunable_config_trial_group_id', + 'status', + ] ) # Get each trial's config in wide format. 
- configs_stmt = ( - schema.trial.select() - .with_only_columns( - schema.trial.c.trial_id, - schema.trial.c.config_id, - schema.config_param.c.param_id, - schema.config_param.c.param_value, - ) - .where( - schema.trial.c.exp_id == experiment_id, - ) - .join( - schema.config_param, - schema.config_param.c.config_id == schema.trial.c.config_id, - isouter=True, - ) - .order_by( - schema.trial.c.trial_id, - schema.config_param.c.param_id, - ) + configs_stmt = schema.trial.select().with_only_columns( + schema.trial.c.trial_id, + schema.trial.c.config_id, + schema.config_param.c.param_id, + schema.config_param.c.param_value, + ).where( + schema.trial.c.exp_id == experiment_id, + ).join( + schema.config_param, + schema.config_param.c.config_id == schema.trial.c.config_id, + isouter=True + ).order_by( + schema.trial.c.trial_id, + schema.config_param.c.param_id, ) if tunable_config_id is not None: configs_stmt = configs_stmt.where( @@ -179,67 +150,41 @@ def get_results_df( ) configs = conn.execute(configs_stmt) configs_df = pandas.DataFrame( - [ - ( - row.trial_id, - row.config_id, - ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, - row.param_value, - ) - for row in configs.fetchall() - ], - columns=["trial_id", "tunable_config_id", "param", "value"], + [(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value) + for row in configs.fetchall()], + columns=['trial_id', 'tunable_config_id', 'param', 'value'] ).pivot( - index=["trial_id", "tunable_config_id"], - columns="param", - values="value", + index=["trial_id", "tunable_config_id"], columns="param", values="value", ) - configs_df = configs_df.apply(pandas.to_numeric, errors="coerce").fillna(configs_df) # type: ignore[assignment] # (fp) + configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp) # Get each trial's results in wide format. 
- results_stmt = ( - schema.trial_result.select() - .with_only_columns( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, - schema.trial_result.c.metric_value, - ) - .where( - schema.trial_result.c.exp_id == experiment_id, - ) - .order_by( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, - ) + results_stmt = schema.trial_result.select().with_only_columns( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + schema.trial_result.c.metric_value, + ).where( + schema.trial_result.c.exp_id == experiment_id, + ).order_by( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, ) if tunable_config_id is not None: - results_stmt = results_stmt.join( - schema.trial, - and_( - schema.trial.c.exp_id == schema.trial_result.c.exp_id, - schema.trial.c.trial_id == schema.trial_result.c.trial_id, - schema.trial.c.config_id == tunable_config_id, - ), - ) + results_stmt = results_stmt.join(schema.trial, and_( + schema.trial.c.exp_id == schema.trial_result.c.exp_id, + schema.trial.c.trial_id == schema.trial_result.c.trial_id, + schema.trial.c.config_id == tunable_config_id, + )) results = conn.execute(results_stmt) results_df = pandas.DataFrame( - [ - ( - row.trial_id, - ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, - row.metric_value, - ) - for row in results.fetchall() - ], - columns=["trial_id", "metric", "value"], + [(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value) + for row in results.fetchall()], + columns=['trial_id', 'metric', 'value'] ).pivot( - index="trial_id", - columns="metric", - values="value", + index="trial_id", columns="metric", values="value", ) - results_df = results_df.apply(pandas.to_numeric, errors="coerce").fillna(results_df) # type: ignore[assignment] # (fp) + results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp) # Concat the trials, configs, and results. - return trials_df.merge( - configs_df, on=["trial_id", "tunable_config_id"], how="left" - ).merge(results_df, on="trial_id", how="left") + return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \ + .merge(results_df, on="trial_id", how="left") diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index e231188f71..58ee3dddb5 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -29,18 +29,15 @@ class Experiment(Storage.Experiment): Logic for retrieving and storing the results of a single experiment. """ - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal["min", "max"]], - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal['min', 'max']]): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -58,22 +55,18 @@ def _setup(self) -> None: # Get git info and the last trial ID for the experiment. 
# pylint: disable=not-callable exp_info = conn.execute( - self._schema.experiment.select() - .with_only_columns( + self._schema.experiment.select().with_only_columns( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, func.max(self._schema.trial.c.trial_id).label("trial_id"), - ) - .join( + ).join( self._schema.trial, self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id, - isouter=True, - ) - .where( + isouter=True + ).where( self._schema.experiment.c.exp_id == self._experiment_id, - ) - .group_by( + ).group_by( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, @@ -82,47 +75,33 @@ def _setup(self) -> None: if exp_info is None: _LOG.info("Start new experiment: %s", self._experiment_id) # It's a new experiment: create a record for it in the database. - conn.execute( - self._schema.experiment.insert().values( - exp_id=self._experiment_id, - description=self._description, - git_repo=self._git_repo, - git_commit=self._git_commit, - root_env_config=self._root_env_config, - ) - ) - conn.execute( - self._schema.objectives.insert().values( - [ - { - "exp_id": self._experiment_id, - "optimization_target": opt_target, - "optimization_direction": opt_dir, - } - for (opt_target, opt_dir) in self.opt_targets.items() - ] - ) - ) + conn.execute(self._schema.experiment.insert().values( + exp_id=self._experiment_id, + description=self._description, + git_repo=self._git_repo, + git_commit=self._git_commit, + root_env_config=self._root_env_config, + )) + conn.execute(self._schema.objectives.insert().values([ + { + "exp_id": self._experiment_id, + "optimization_target": opt_target, + "optimization_direction": opt_dir, + } + for (opt_target, opt_dir) in self.opt_targets.items() + ])) else: if exp_info.trial_id is not None: self._trial_id = exp_info.trial_id + 1 - _LOG.info( - "Continue experiment: %s last trial: %s resume from: %d", - self._experiment_id, - exp_info.trial_id, - self._trial_id, - ) + _LOG.info("Continue experiment: %s last trial: %s resume from: %d", + self._experiment_id, exp_info.trial_id, self._trial_id) # TODO: Sanity check that certain critical configs (e.g., # objectives) haven't changed to be incompatible such that a new # experiment should be started (possibly by prewarming with the # previous one). 
if exp_info.git_commit != self._git_commit: - _LOG.warning( - "Experiment %s git expected: %s %s", - self, - exp_info.git_repo, - exp_info.git_commit, - ) + _LOG.warning("Experiment %s git expected: %s %s", + self, exp_info.git_repo, exp_info.git_commit) def merge(self, experiment_ids: List[str]) -> None: _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids) @@ -130,55 +109,38 @@ def merge(self, experiment_ids: List[str]) -> None: def load_tunable_config(self, config_id: int) -> Dict[str, Any]: with self._engine.connect() as conn: - return self._get_key_val( - conn, self._schema.config_param, "param", config_id=config_id - ) + return self._get_key_val(conn, self._schema.config_param, "param", config_id=config_id) def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select() - .where( + self._schema.trial_telemetry.select().where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == trial_id, - ) - .order_by( + self._schema.trial_telemetry.c.trial_id == trial_id + ).order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) ) # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. - return [ - ( - utcify_timestamp(row.ts, origin="utc"), - row.metric_id, - row.metric_value, - ) - for row in cur_telemetry.fetchall() - ] + return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) + for row in cur_telemetry.fetchall()] - def load( - self, - last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load(self, last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: with self._engine.connect() as conn: cur_trials = conn.execute( - self._schema.trial.select() - .with_only_columns( + self._schema.trial.select().with_only_columns( self._schema.trial.c.trial_id, self._schema.trial.c.config_id, self._schema.trial.c.status, - ) - .where( + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id > last_trial_id, - self._schema.trial.c.status.in_( - ["SUCCEEDED", "FAILED", "TIMED_OUT"] - ), - ) - .order_by( + self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']), + ).order_by( self._schema.trial.c.trial_id.asc(), ) ) @@ -192,33 +154,19 @@ def load( stat = Status[trial.status] status.append(stat) trial_ids.append(trial.trial_id) - configs.append( - self._get_key_val( - conn, - self._schema.config_param, - "param", - config_id=trial.config_id, - ) - ) + configs.append(self._get_key_val( + conn, self._schema.config_param, "param", config_id=trial.config_id)) if stat.is_succeeded(): - scores.append( - self._get_key_val( - conn, - self._schema.trial_result, - "metric", - exp_id=self._experiment_id, - trial_id=trial.trial_id, - ) - ) + scores.append(self._get_key_val( + conn, self._schema.trial_result, "metric", + exp_id=self._experiment_id, trial_id=trial.trial_id)) else: scores.append(None) return (trial_ids, configs, scores, status) @staticmethod - def _get_key_val( - conn: Connection, table: Table, field: str, **kwargs: Any - ) -> Dict[str, Any]: + def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]: """ Helper method to retrieve key-value pairs from the database. 
(E.g., configurations, results, and telemetry). @@ -227,63 +175,49 @@ def _get_key_val( select( column(f"{field}_id"), column(f"{field}_value"), + ).select_from(table).where( + *[column(key) == val for (key, val) in kwargs.items()] ) - .select_from(table) - .where(*[column(key) == val for (key, val) in kwargs.items()]) ) # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts. - return dict( - row._tuple() for row in cur_result.fetchall() - ) # pylint: disable=protected-access + return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access @staticmethod - def _save_params( - conn: Connection, table: Table, params: Dict[str, Any], **kwargs: Any - ) -> None: + def _save_params(conn: Connection, table: Table, + params: Dict[str, Any], **kwargs: Any) -> None: if not params: return - conn.execute( - table.insert(), - [ - {**kwargs, "param_id": key, "param_value": nullable(str, val)} - for (key, val) in params.items() - ], - ) + conn.execute(table.insert(), [ + { + **kwargs, + "param_id": key, + "param_value": nullable(str, val) + } + for (key, val) in params.items() + ]) - def pending_trials( - self, timestamp: datetime, *, running: bool - ) -> Iterator[Storage.Trial]: + def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]: timestamp = utcify_timestamp(timestamp, origin="local") - _LOG.info( - "Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp - ) + _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp) if running: - pending_status = ["PENDING", "READY", "RUNNING"] + pending_status = ['PENDING', 'READY', 'RUNNING'] else: - pending_status = ["PENDING"] + pending_status = ['PENDING'] with self._engine.connect() as conn: - cur_trials = conn.execute( - self._schema.trial.select().where( - self._schema.trial.c.exp_id == self._experiment_id, - ( - self._schema.trial.c.ts_start.is_(None) - | (self._schema.trial.c.ts_start <= timestamp) - ), - self._schema.trial.c.ts_end.is_(None), - self._schema.trial.c.status.in_(pending_status), - ) - ) + cur_trials = conn.execute(self._schema.trial.select().where( + self._schema.trial.c.exp_id == self._experiment_id, + (self._schema.trial.c.ts_start.is_(None) | + (self._schema.trial.c.ts_start <= timestamp)), + self._schema.trial.c.ts_end.is_(None), + self._schema.trial.c.status.in_(pending_status), + )) for trial in cur_trials.fetchall(): tunables = self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id - ) + conn, self._schema.config_param, "param", + config_id=trial.config_id) config = self._get_key_val( - conn, - self._schema.trial_param, - "param", - exp_id=self._experiment_id, - trial_id=trial.trial_id, - ) + conn, self._schema.trial_param, "param", + exp_id=self._experiment_id, trial_id=trial.trial_id) yield Trial( engine=self._engine, schema=self._schema, @@ -301,59 +235,42 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: Get the config ID for the given tunables. If the config does not exist, create a new record for it. 
""" - config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest() - cur_config = conn.execute( - self._schema.config.select().where( - self._schema.config.c.config_hash == config_hash - ) - ).fetchone() + config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest() + cur_config = conn.execute(self._schema.config.select().where( + self._schema.config.c.config_hash == config_hash + )).fetchone() if cur_config is not None: return int(cur_config.config_id) # mypy doesn't know it's always int # Config not found, create a new one: - config_id: int = conn.execute( - self._schema.config.insert().values(config_hash=config_hash) - ).inserted_primary_key[0] + config_id: int = conn.execute(self._schema.config.insert().values( + config_hash=config_hash)).inserted_primary_key[0] self._save_params( - conn, - self._schema.config_param, + conn, self._schema.config_param, {tunable.name: tunable.value for (tunable, _group) in tunables}, - config_id=config_id, - ) + config_id=config_id) return config_id - def new_trial( - self, - tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None, - ) -> Storage.Trial: + def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None) -> Storage.Trial: ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local") - _LOG.debug( - "Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start - ) + _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start) with self._engine.begin() as conn: try: config_id = self._get_config_id(conn, tunables) - conn.execute( - self._schema.trial.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - config_id=config_id, - ts_start=ts_start, - status="PENDING", - ) - ) + conn.execute(self._schema.trial.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + config_id=config_id, + ts_start=ts_start, + status='PENDING', + )) # Note: config here is the framework config, not the target # environment config (i.e., tunables). if config is not None: self._save_params( - conn, - self._schema.trial_param, - config, - exp_id=self._experiment_id, - trial_id=self._trial_id, - ) + conn, self._schema.trial_param, config, + exp_id=self._experiment_id, trial_id=self._trial_id) trial = Trial( engine=self._engine, diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index a370ad1060..eaa6e1041f 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -35,17 +35,14 @@ class ExperimentSqlData(ExperimentData): scripts and mlos_bench configuration files. 
""" - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str): super().__init__( experiment_id=experiment_id, description=description, @@ -60,11 +57,9 @@ def __init__( def objectives(self) -> Dict[str, Literal["min", "max"]]: with self._engine.connect() as conn: objectives_db_data = conn.execute( - self._schema.objectives.select() - .where( + self._schema.objectives.select().where( self._schema.objectives.c.exp_id == self._experiment_id, - ) - .order_by( + ).order_by( self._schema.objectives.c.weight.desc(), self._schema.objectives.c.optimization_target.asc(), ) @@ -85,19 +80,13 @@ def trials(self) -> Dict[int, TrialData]: def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: with self._engine.connect() as conn: tunable_config_trial_groups = conn.execute( - self._schema.trial.select() - .with_only_columns( + self._schema.trial.select().with_only_columns( self._schema.trial.c.config_id, - func.min(self._schema.trial.c.trial_id) - .cast(Integer) - .label( # pylint: disable=not-callable - "tunable_config_trial_group_id" - ), - ) - .where( + func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable + 'tunable_config_trial_group_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, - ) - .group_by( + ).group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -117,14 +106,11 @@ def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: def tunable_configs(self) -> Dict[int, TunableConfigData]: with self._engine.connect() as conn: tunable_configs = conn.execute( - self._schema.trial.select() - .with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label("config_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label('config_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, - ) - .group_by( + ).group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -153,30 +139,20 @@ def default_tunable_config_id(self) -> Optional[int]: """ with self._engine.connect() as conn: query_results = conn.execute( - self._schema.trial.select() - .with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label("config_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label('config_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial_param.select() - .with_only_columns( - func.min(self._schema.trial_param.c.trial_id) - .cast(Integer) - .label( # pylint: disable=not-callable - "first_trial_id_with_defaults" - ), - ) - .where( + self._schema.trial_param.select().with_only_columns( + func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable + "first_trial_id_with_defaults"), + ).where( self._schema.trial_param.c.exp_id == self._experiment_id, self._schema.trial_param.c.param_id == "is_defaults", - func.lower( - self._schema.trial_param.c.param_value, type_=String - ).in_(["1", "true"]), - ) - .scalar_subquery() - ), + func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]), + ).scalar_subquery() + ) ) ) 
min_default_trial_row = query_results.fetchone() @@ -185,24 +161,17 @@ def default_tunable_config_id(self) -> Optional[int]: return min_default_trial_row._tuple()[0] # fallback logic - assume minimum trial_id for experiment query_results = conn.execute( - self._schema.trial.select() - .with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label("config_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label('config_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial.select() - .with_only_columns( - func.min(self._schema.trial.c.trial_id) - .cast(Integer) - .label("first_trial_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"), + ).where( self._schema.trial.c.exp_id == self._experiment_id, - ) - .scalar_subquery() - ), + ).scalar_subquery() + ) ) ) min_trial_row = query_results.fetchone() diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index abc5ab27ac..9a1eca2744 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -80,6 +80,7 @@ def __init__(self, engine: Engine): Column("root_env_config", String(1024), nullable=False), Column("git_repo", String(1024), nullable=False), Column("git_commit", String(40), nullable=False), + PrimaryKeyConstraint("exp_id"), ) @@ -94,29 +95,20 @@ def __init__(self, engine: Engine): # Will need to adjust the insert and return values to support this # eventually. Column("weight", Float, nullable=True), + PrimaryKeyConstraint("exp_id", "optimization_target"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ) # A workaround for SQLAlchemy issue with autoincrement in DuckDB: if engine.dialect.name == "duckdb": - seq_config_id = Sequence("seq_config_id") - col_config_id = Column( - "config_id", - Integer, - seq_config_id, - server_default=seq_config_id.next_value(), - nullable=False, - primary_key=True, - ) + seq_config_id = Sequence('seq_config_id') + col_config_id = Column("config_id", Integer, seq_config_id, + server_default=seq_config_id.next_value(), + nullable=False, primary_key=True) else: - col_config_id = Column( - "config_id", - Integer, - nullable=False, - primary_key=True, - autoincrement=True, - ) + col_config_id = Column("config_id", Integer, nullable=False, + primary_key=True, autoincrement=True) self.config = Table( "config", @@ -135,6 +127,7 @@ def __init__(self, engine: Engine): Column("ts_end", DateTime), # Should match the text IDs of `mlos_bench.environments.Status` enum: Column("status", String(self._STATUS_LEN), nullable=False), + PrimaryKeyConstraint("exp_id", "trial_id"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), @@ -148,6 +141,7 @@ def __init__(self, engine: Engine): Column("config_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), + PrimaryKeyConstraint("config_id", "param_id"), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), ) @@ -161,10 +155,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), + PrimaryKeyConstraint("exp_id", "trial_id", 
"param_id"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) self.trial_status = Table( @@ -174,10 +168,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("status", String(self._STATUS_LEN), nullable=False), + UniqueConstraint("exp_id", "trial_id", "ts"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) self.trial_result = Table( @@ -187,10 +181,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), + PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) self.trial_telemetry = Table( @@ -201,15 +195,15 @@ def __init__(self, engine: Engine): Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), + UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) _LOG.debug("Schema: %s", self._meta) - def create(self) -> "DbSchema": + def create(self) -> 'DbSchema': """ Create the DB schema. """ diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index a52861d3ad..bde38575bd 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -27,12 +27,10 @@ class SqlStorage(Storage): An implementation of the Storage interface using SQLAlchemy backend. 
""" - def __init__( - self, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(config, global_config, service) lazy_schema_create = self._config.pop("lazy_schema_create", False) self._log_sql = self._config.pop("log_sql", False) @@ -49,7 +47,7 @@ def __init__( @property def _schema(self) -> DbSchema: """Lazily create schema upon first access.""" - if not hasattr(self, "_db_schema"): + if not hasattr(self, '_db_schema'): self._db_schema = DbSchema(self._engine).create() if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("DDL statements:\n%s", self._schema) @@ -58,16 +56,13 @@ def _schema(self) -> DbSchema: def __repr__(self) -> str: return self._repr - def experiment( - self, - *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal["min", "max"]], - ) -> Storage.Experiment: + def experiment(self, *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment: return Experiment( engine=self._engine, schema=self._schema, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index d730aef0aa..7ac7958845 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -27,18 +27,15 @@ class Trial(Storage.Trial): Store the results of a single run of the experiment in SQL database. """ - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - config_id: int, - opt_targets: Dict[str, Literal["min", "max"]], - config: Optional[Dict[str, Any]] = None, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + config_id: int, + opt_targets: Dict[str, Literal['min', 'max']], + config: Optional[Dict[str, Any]] = None): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -50,12 +47,9 @@ def __init__( self._engine = engine self._schema = schema - def update( - self, - status: Status, - timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None, - ) -> Optional[Dict[str, Any]]: + def update(self, status: Status, timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") metrics = super().update(status, timestamp, metrics) @@ -65,16 +59,13 @@ def update( if status.is_completed(): # Final update of the status and ts_end: cur_status = conn.execute( - self._schema.trial.update() - .where( + self._schema.trial.update().where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] - ), - ) - .values( + ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), + ).values( status=status.name, ts_end=timestamp, ) @@ -82,96 +73,67 @@ def update( if cur_status.rowcount not in {1, -1}: _LOG.warning("Trial %s :: update failed: %s", self, status) raise RuntimeError( - f"Failed to update the status of the trial {self} to {status}." 
- + f" ({cur_status.rowcount} rows)" - ) + f"Failed to update the status of the trial {self} to {status}." + + f" ({cur_status.rowcount} rows)") if metrics: - conn.execute( - self._schema.trial_result.insert().values( - [ - { - "exp_id": self._experiment_id, - "trial_id": self._trial_id, - "metric_id": key, - "metric_value": nullable(str, val), - } - for (key, val) in metrics.items() - ] - ) - ) + conn.execute(self._schema.trial_result.insert().values([ + { + "exp_id": self._experiment_id, + "trial_id": self._trial_id, + "metric_id": key, + "metric_value": nullable(str, val), + } + for (key, val) in metrics.items() + ])) else: # Update of the status and ts_start when starting the trial: assert metrics is None, f"Unexpected metrics for status: {status}" cur_status = conn.execute( - self._schema.trial.update() - .where( + self._schema.trial.update().where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - [ - "RUNNING", - "SUCCEEDED", - "CANCELED", - "FAILED", - "TIMED_OUT", - ] - ), - ) - .values( + ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), + ).values( status=status.name, ts_start=timestamp, ) ) if cur_status.rowcount not in {1, -1}: # Keep the old status and timestamp if already running, but log it. - _LOG.warning( - "Trial %s :: cannot be updated to: %s", self, status - ) + _LOG.warning("Trial %s :: cannot be updated to: %s", self, status) except Exception: conn.rollback() raise return metrics - def update_telemetry( - self, - status: Status, - timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]], - ) -> None: + def update_telemetry(self, status: Status, timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]]) -> None: super().update_telemetry(status, timestamp, metrics) # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") - metrics = [ - (utcify_timestamp(ts, origin="local"), key, val) - for (ts, key, val) in metrics - ] + metrics = [(utcify_timestamp(ts, origin="local"), key, val) for (ts, key, val) in metrics] # NOTE: Not every SQLAlchemy dialect supports `Insert.on_conflict_do_nothing()` # and we need to keep `.update_telemetry()` idempotent; hence a loop instead of # a bulk upsert. # See Also: comments in with self._engine.begin() as conn: self._update_status(conn, status, timestamp) - for metric_ts, key, val in metrics: + for (metric_ts, key, val) in metrics: with self._engine.begin() as conn: try: - conn.execute( - self._schema.trial_telemetry.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=metric_ts, - metric_id=key, - metric_value=nullable(str, val), - ) - ) + conn.execute(self._schema.trial_telemetry.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=metric_ts, + metric_id=key, + metric_value=nullable(str, val), + )) except IntegrityError as ex: - _LOG.warning( - "Record already exists: %s :: %s", (metric_ts, key, val), ex - ) + _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex) - def _update_status( - self, conn: Connection, status: Status, timestamp: datetime - ) -> None: + def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None: """ Insert a new status record into the database. This call is idempotent. 
@@ -179,18 +141,12 @@ def _update_status( # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") try: - conn.execute( - self._schema.trial_status.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=timestamp, - status=status.name, - ) - ) + conn.execute(self._schema.trial_status.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=timestamp, + status=status.name, + )) except IntegrityError as ex: - _LOG.warning( - "Status with that timestamp already exists: %s %s :: %s", - self, - timestamp, - ex, - ) + _LOG.warning("Status with that timestamp already exists: %s %s :: %s", + self, timestamp, ex) diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index b5551bd856..5a6f8a5ee8 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -29,18 +29,15 @@ class TrialSqlData(TrialData): An interface to access the trial data stored in the SQL DB. """ - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - trial_id: int, - config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + trial_id: int, + config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status): super().__init__( experiment_id=experiment_id, trial_id=trial_id, @@ -59,11 +56,8 @@ def tunable_config(self) -> TunableConfigData: Note: this corresponds to the Trial object's "tunables" property. """ - return TunableConfigSqlData( - engine=self._engine, - schema=self._schema, - tunable_config_id=self._tunable_config_id, - ) + return TunableConfigSqlData(engine=self._engine, schema=self._schema, + tunable_config_id=self._tunable_config_id) @property def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": @@ -74,13 +68,9 @@ def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": from mlos_bench.storage.sql.tunable_config_trial_group_data import ( TunableConfigTrialGroupSqlData, ) - - return TunableConfigTrialGroupSqlData( - engine=self._engine, - schema=self._schema, - experiment_id=self._experiment_id, - tunable_config_id=self._tunable_config_id, - ) + return TunableConfigTrialGroupSqlData(engine=self._engine, schema=self._schema, + experiment_id=self._experiment_id, + tunable_config_id=self._tunable_config_id) @property def results_df(self) -> pandas.DataFrame: @@ -89,19 +79,16 @@ def results_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_results = conn.execute( - self._schema.trial_result.select() - .where( + self._schema.trial_result.select().where( self._schema.trial_result.c.exp_id == self._experiment_id, - self._schema.trial_result.c.trial_id == self._trial_id, - ) - .order_by( + self._schema.trial_result.c.trial_id == self._trial_id + ).order_by( self._schema.trial_result.c.metric_id, ) ) return pandas.DataFrame( [(row.metric_id, row.metric_value) for row in cur_results.fetchall()], - columns=["metric", "value"], - ) + columns=['metric', 'value']) @property def telemetry_df(self) -> pandas.DataFrame: @@ -110,12 +97,10 @@ def telemetry_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select() - .where( + self._schema.trial_telemetry.select().where( 
self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == self._trial_id, - ) - .order_by( + self._schema.trial_telemetry.c.trial_id == self._trial_id + ).order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) @@ -123,16 +108,8 @@ def telemetry_df(self) -> pandas.DataFrame: # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. return pandas.DataFrame( - [ - ( - utcify_timestamp(row.ts, origin="utc"), - row.metric_id, - row.metric_value, - ) - for row in cur_telemetry.fetchall() - ], - columns=["ts", "metric", "value"], - ) + [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()], + columns=['ts', 'metric', 'value']) @property def metadata_df(self) -> pandas.DataFrame: @@ -143,16 +120,13 @@ def metadata_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_params = conn.execute( - self._schema.trial_param.select() - .where( + self._schema.trial_param.select().where( self._schema.trial_param.c.exp_id == self._experiment_id, - self._schema.trial_param.c.trial_id == self._trial_id, - ) - .order_by( + self._schema.trial_param.c.trial_id == self._trial_id + ).order_by( self._schema.trial_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_params.fetchall()], - columns=["parameter", "value"], - ) + columns=['parameter', 'value']) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py index 2441f70b9c..e484979790 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py @@ -20,7 +20,10 @@ class TunableConfigSqlData(TunableConfigData): A configuration in this context is the set of tunable parameter values. """ - def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + tunable_config_id: int): super().__init__(tunable_config_id=tunable_config_id) self._engine = engine self._schema = schema @@ -29,13 +32,12 @@ def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int): def config_df(self) -> pandas.DataFrame: with self._engine.connect() as conn: cur_config = conn.execute( - self._schema.config_param.select() - .where(self._schema.config_param.c.config_id == self._tunable_config_id) - .order_by( + self._schema.config_param.select().where( + self._schema.config_param.c.config_id == self._tunable_config_id + ).order_by( self._schema.config_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_config.fetchall()], - columns=["parameter", "value"], - ) + columns=['parameter', 'value']) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index 4c3882c9a0..eb389a5940 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -33,15 +33,12 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData): (e.g., for repeats), which we call a (tunable) config trial group. 
""" - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None): super().__init__( experiment_id=experiment_id, tunable_config_id=tunable_config_id, @@ -56,28 +53,20 @@ def _get_tunable_config_trial_group_id(self) -> int: """ with self._engine.connect() as conn: tunable_config_trial_group = conn.execute( - self._schema.trial.select() - .with_only_columns( - func.min(self._schema.trial.c.trial_id) - .cast(Integer) - .label( # pylint: disable=not-callable - "tunable_config_trial_group_id" - ), - ) - .where( + self._schema.trial.select().with_only_columns( + func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable + 'tunable_config_trial_group_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.config_id == self._tunable_config_id, - ) - .group_by( + ).group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) ) row = tunable_config_trial_group.fetchone() assert row is not None - return row._tuple()[ - 0 - ] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy + return row._tuple()[0] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy @property def tunable_config(self) -> TunableConfigData: @@ -97,12 +86,8 @@ def trials(self) -> Dict[int, "TrialData"]: trials : Dict[int, TrialData] A dictionary of the trials' data, keyed by trial id. """ - return common.get_trials( - self._engine, self._schema, self._experiment_id, self._tunable_config_id - ) + return common.get_trials(self._engine, self._schema, self._experiment_id, self._tunable_config_id) @property def results_df(self) -> pandas.DataFrame: - return common.get_results_df( - self._engine, self._schema, self._experiment_id, self._tunable_config_id - ) + return common.get_results_df(self._engine, self._schema, self._experiment_id, self._tunable_config_id) diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index 22e629fc82..220f3d812c 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -13,9 +13,9 @@ from mlos_bench.storage.base_storage import Storage -def from_config( - config_file: str, global_configs: Optional[List[str]] = None, **kwargs: Any -) -> Storage: +def from_config(config_file: str, + global_configs: Optional[List[str]] = None, + **kwargs: Any) -> Storage: """ Create a new storage object from JSON5 config file. 
@@ -36,7 +36,7 @@ def from_config( config_path: List[str] = kwargs.get("config_path", []) config_loader = ConfigPersistenceService({"config_path": config_path}) global_config = {} - for fname in global_configs or []: + for fname in (global_configs or []): config = config_loader.load_config(fname, ConfigSchema.GLOBALS) global_config.update(config) config_path += config.get("config_path", []) diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py index 64cc6c953e..a4610da8de 100644 --- a/mlos_bench/mlos_bench/storage/util.py +++ b/mlos_bench/mlos_bench/storage/util.py @@ -25,22 +25,16 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu A dataframe with exactly two columns, 'parameter' (or 'metric') and 'value', where 'parameter' is a string and 'value' is some TunableValue or None. """ - if dataframe.columns.tolist() == ["metric", "value"]: + if dataframe.columns.tolist() == ['metric', 'value']: dataframe = dataframe.copy() - dataframe.rename(columns={"metric": "parameter"}, inplace=True) - assert dataframe.columns.tolist() == ["parameter", "value"] + dataframe.rename(columns={'metric': 'parameter'}, inplace=True) + assert dataframe.columns.tolist() == ['parameter', 'value'] data = {} - for _, row in dataframe.astype("O").iterrows(): - if not isinstance(row["value"], TunableValueTypeTuple): - raise TypeError( - f"Invalid column type: {type(row['value'])} value: {row['value']}" - ) - assert isinstance(row["parameter"], str) - if row["parameter"] in data: + for _, row in dataframe.astype('O').iterrows(): + if not isinstance(row['value'], TunableValueTypeTuple): + raise TypeError(f"Invalid column type: {type(row['value'])} value: {row['value']}") + assert isinstance(row['parameter'], str) + if row['parameter'] in data: raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe") - data[row["parameter"]] = ( - try_parse_val(row["value"]) - if isinstance(row["value"], str) - else row["value"] - ) + data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value'] return data diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index a3d53a38db..26aa142441 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -29,36 +29,26 @@ None, ] ZONE_INFO: List[Optional[tzinfo]] = [ - nullable(pytz.timezone, zone_name) for zone_name in ZONE_NAMES + nullable(pytz.timezone, zone_name) + for zone_name in ZONE_NAMES ] # A decorator for tests that require docker. # Use with @requires_docker above a test_...() function. -DOCKER = shutil.which("docker") +DOCKER = shutil.which('docker') if DOCKER: - cmd = run( - "docker builder inspect default || docker buildx inspect default", - shell=True, - check=False, - capture_output=True, - ) + cmd = run("docker builder inspect default || docker buildx inspect default", shell=True, check=False, capture_output=True) stdout = cmd.stdout.decode() - if cmd.returncode != 0 or not any( - line for line in stdout.splitlines() if "Platform" in line and "linux" in line - ): + if cmd.returncode != 0 or not any(line for line in stdout.splitlines() if 'Platform' in line and 'linux' in line): debug("Docker is available but missing support for targeting linux platform.") DOCKER = None -requires_docker = pytest.mark.skipif( - not DOCKER, reason="Docker with Linux support is not available on this system." 
-) +requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.') # A decorator for tests that require ssh. # Use with @requires_ssh above a test_...() function. -SSH = shutil.which("ssh") -requires_ssh = pytest.mark.skipif( - not SSH, reason="ssh is not available on this system." -) +SSH = shutil.which('ssh') +requires_ssh = pytest.mark.skipif(not SSH, reason='ssh is not available on this system.') # A common seed to use to avoid tracking down race conditions and intermingling # issues of seeds across tests that run in non-deterministic parallel orders. @@ -141,18 +131,10 @@ def are_dir_trees_equal(dir1: str, dir2: str) -> bool: """ # See Also: https://stackoverflow.com/a/6681395 dirs_cmp = filecmp.dircmp(dir1, dir2) - if ( - len(dirs_cmp.left_only) > 0 - or len(dirs_cmp.right_only) > 0 - or len(dirs_cmp.funny_files) > 0 - ): - warning( - f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}" - ) + if len(dirs_cmp.left_only) > 0 or len(dirs_cmp.right_only) > 0 or len(dirs_cmp.funny_files) > 0: + warning(f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}") return False - (_, mismatch, errors) = filecmp.cmpfiles( - dir1, dir2, dirs_cmp.common_files, shallow=False - ) + (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) if len(mismatch) > 0 or len(errors) > 0: warning(f"Found differences in files:\n{mismatch}\n{errors}") return False diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index d6ee5583bb..4d728b4037 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -18,16 +18,12 @@ from importlib.resources import files -BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace( - "\\", "/" -) +BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace("\\", "/") -def locate_config_examples( - root_dir: str, - config_examples_dir: str, - examples_filter: Optional[Callable[[List[str]], List[str]]] = None, -) -> List[str]: +def locate_config_examples(root_dir: str, + config_examples_dir: str, + examples_filter: Optional[Callable[[List[str]], List[str]]] = None) -> List[str]: """Locates all config examples in the given directory. 
Parameters diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index 8e20001926..e1e26d7d8b 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -43,9 +43,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ - *locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs - ), + *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), ] assert configs @@ -53,9 +51,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.skip(reason="Use full Launcher test (below) instead now.") @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: # pragma: no cover +def test_load_cli_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: # pragma: no cover """Tests loading a config example.""" # pylint: disable=too-complex config = config_loader_service.load_config(config_path, ConfigSchema.CLI) @@ -65,9 +61,7 @@ def test_load_cli_config_examples( assert isinstance(config_paths, list) config_paths.reverse() for path in config_paths: - config_loader_service._config_path.insert( - 0, path - ) # pylint: disable=protected-access + config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access # Foreach arg that references another file, see if we can at least load that too. args_to_skip = { @@ -84,39 +78,27 @@ def test_load_cli_config_examples( if arg == "globals": for path in config[arg]: - sub_config = config_loader_service.load_config( - path, ConfigSchema.GLOBALS - ) + sub_config = config_loader_service.load_config(path, ConfigSchema.GLOBALS) assert isinstance(sub_config, dict) elif arg == "environment": - sub_config = config_loader_service.load_config( - config[arg], ConfigSchema.ENVIRONMENT - ) + sub_config = config_loader_service.load_config(config[arg], ConfigSchema.ENVIRONMENT) assert isinstance(sub_config, dict) elif arg == "optimizer": - sub_config = config_loader_service.load_config( - config[arg], ConfigSchema.OPTIMIZER - ) + sub_config = config_loader_service.load_config(config[arg], ConfigSchema.OPTIMIZER) assert isinstance(sub_config, dict) elif arg == "storage": - sub_config = config_loader_service.load_config( - config[arg], ConfigSchema.STORAGE - ) + sub_config = config_loader_service.load_config(config[arg], ConfigSchema.STORAGE) assert isinstance(sub_config, dict) elif arg == "tunable_values": for path in config[arg]: - sub_config = config_loader_service.load_config( - path, ConfigSchema.TUNABLE_VALUES - ) + sub_config = config_loader_service.load_config(path, ConfigSchema.TUNABLE_VALUES) assert isinstance(sub_config, dict) else: raise NotImplementedError(f"Unhandled arg {arg} in config {config_path}") @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples_via_launcher( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example via the Launcher.""" config = 
config_loader_service.load_config(config_path, ConfigSchema.CLI) assert isinstance(config, dict) @@ -124,38 +106,29 @@ def test_load_cli_config_examples_via_launcher( # Try to load the CLI config by instantiating a launcher. # To do this we need to make sure to give it a few extra paths and globals # to look for for our examples. - cli_args = ( - f"--config {config_path}" - + f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" - + f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" - + f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" - ) - launcher = Launcher( - description=__name__, long_text=config_path, argv=cli_args.split() - ) + cli_args = f"--config {config_path}" + \ + f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + \ + f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + \ + f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" + launcher = Launcher(description=__name__, long_text=config_path, argv=cli_args.split()) assert launcher # Check that some parts of that config are loaded. - assert ( - ConfigPersistenceService.BUILTIN_CONFIG_PATH - in launcher.config_loader.config_paths - ) + assert ConfigPersistenceService.BUILTIN_CONFIG_PATH in launcher.config_loader.config_paths if config_paths := config.get("config_path"): assert isinstance(config_paths, list) for path in config_paths: # Note: Checks that the order is maintained are handled in launcher_parse_args.py - assert any( - config_path.endswith(path) - for config_path in launcher.config_loader.config_paths - ), f"Expected {path} to be in {launcher.config_loader.config_paths}" + assert any(config_path.endswith(path) for config_path in launcher.config_loader.config_paths), \ + f"Expected {path} to be in {launcher.config_loader.config_paths}" - if "experiment_id" in config: - assert launcher.global_config["experiment_id"] == config["experiment_id"] - if "trial_id" in config: - assert launcher.global_config["trial_id"] == config["trial_id"] + if 'experiment_id' in config: + assert launcher.global_config['experiment_id'] == config['experiment_id'] + if 'trial_id' in config: + assert launcher.global_config['trial_id'] == config['trial_id'] - expected_log_level = logging.getLevelName(config.get("log_level", "INFO")) + expected_log_level = logging.getLevelName(config.get('log_level', "INFO")) if isinstance(expected_log_level, int): expected_log_level = logging.getLevelName(expected_log_level) current_log_level = logging.getLevelName(logging.root.getEffectiveLevel()) @@ -163,7 +136,7 @@ def test_load_cli_config_examples_via_launcher( # TODO: Check that the log_file handler is set correctly. - expected_teardown = config.get("teardown", True) + expected_teardown = config.get('teardown', True) assert launcher.teardown == expected_teardown # Note: Testing of "globals" processing handled in launcher_parse_args_test.py @@ -172,30 +145,22 @@ def test_load_cli_config_examples_via_launcher( # Launcher loaded the expected types as well. 
assert isinstance(launcher.environment, Environment) - env_config = launcher.config_loader.load_config( - config["environment"], ConfigSchema.ENVIRONMENT - ) + env_config = launcher.config_loader.load_config(config["environment"], ConfigSchema.ENVIRONMENT) assert check_class_name(launcher.environment, env_config["class"]) assert isinstance(launcher.optimizer, Optimizer) if "optimizer" in config: - opt_config = launcher.config_loader.load_config( - config["optimizer"], ConfigSchema.OPTIMIZER - ) + opt_config = launcher.config_loader.load_config(config["optimizer"], ConfigSchema.OPTIMIZER) assert check_class_name(launcher.optimizer, opt_config["class"]) assert isinstance(launcher.storage, Storage) if "storage" in config: - storage_config = launcher.config_loader.load_config( - config["storage"], ConfigSchema.STORAGE - ) + storage_config = launcher.config_loader.load_config(config["storage"], ConfigSchema.STORAGE) assert check_class_name(launcher.storage, storage_config["class"]) assert isinstance(launcher.scheduler, Scheduler) if "scheduler" in config: - scheduler_config = launcher.config_loader.load_config( - config["scheduler"], ConfigSchema.SCHEDULER - ) + scheduler_config = launcher.config_loader.load_config(config["scheduler"], ConfigSchema.SCHEDULER) assert check_class_name(launcher.scheduler, scheduler_config["class"]) # TODO: Check that the launcher assigns the tunables values as expected. diff --git a/mlos_bench/mlos_bench/tests/config/conftest.py b/mlos_bench/mlos_bench/tests/config/conftest.py index 2c3932a128..fdcb3370cf 100644 --- a/mlos_bench/mlos_bench/tests/config/conftest.py +++ b/mlos_bench/mlos_bench/tests/config/conftest.py @@ -22,11 +22,9 @@ @pytest.fixture def config_loader_service() -> ConfigPersistenceService: """Config loader service fixture.""" - return ConfigPersistenceService( - config={ - "config_path": [ - str(files("mlos_bench.tests.config")), - path_join(str(files("mlos_bench.tests.config")), "globals"), - ] - } - ) + return ConfigPersistenceService(config={ + "config_path": [ + str(files("mlos_bench.tests.config")), + path_join(str(files("mlos_bench.tests.config")), "globals"), + ] + }) diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 6ee34dbc71..42925a0a5d 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -27,24 +27,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" - configs_to_filter = [ - config_path - for config_path in configs_to_filter - if not config_path.endswith("-tunables.jsonc") - ] + configs_to_filter = [config_path for config_path in configs_to_filter if not config_path.endswith("-tunables.jsonc")] return configs_to_filter -configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_environment_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests 
loading an environment config example.""" envs = load_environment_config_examples(config_loader_service, config_path) for env in envs: @@ -52,15 +44,11 @@ def test_load_environment_config_examples( assert isinstance(env, Environment) -def load_environment_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> List[Environment]: +def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]: """Loads an environment config example.""" # Make sure that any "required_args" are provided. - global_config = config_loader_service.load_config( - "experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS - ) - global_config.setdefault("trial_id", 1) # normally populated by Launcher + global_config = config_loader_service.load_config("experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS) + global_config.setdefault('trial_id', 1) # normally populated by Launcher # Make sure we have the required services for the envs being used. mock_service_configs = [ @@ -72,34 +60,24 @@ def load_environment_config_examples( "services/remote/mock/mock_auth_service.jsonc", ] - tunable_groups = TunableGroups() # base tunable groups that all others get built on + tunable_groups = TunableGroups() # base tunable groups that all others get built on for mock_service_config_path in mock_service_configs: - mock_service_config = config_loader_service.load_config( - mock_service_config_path, ConfigSchema.SERVICE - ) - config_loader_service.register( - config_loader_service.build_service( - config=mock_service_config, parent=config_loader_service - ).export() - ) + mock_service_config = config_loader_service.load_config(mock_service_config_path, ConfigSchema.SERVICE) + config_loader_service.register(config_loader_service.build_service( + config=mock_service_config, parent=config_loader_service).export()) envs = config_loader_service.load_environment_list( - config_path, tunable_groups, global_config, service=config_loader_service - ) + config_path, tunable_groups, global_config, service=config_loader_service) return envs -composite_configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/" -) +composite_configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/") assert composite_configs @pytest.mark.parametrize("config_path", composite_configs) -def test_load_composite_env_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a composite env config example.""" envs = load_environment_config_examples(config_loader_service, config_path) assert len(envs) == 1 @@ -112,15 +90,11 @@ def test_load_composite_env_config_examples( assert child_env.tunable_params is not None checked_child_env_groups = set() - for child_tunable, child_group in child_env.tunable_params: + for (child_tunable, child_group) in child_env.tunable_params: # Lookup that tunable in the composite env. assert child_tunable in composite_env.tunable_params - (composite_tunable, composite_group) = ( - composite_env.tunable_params.get_tunable(child_tunable) - ) - assert ( - child_tunable is composite_tunable - ) # Check that the tunables are the same object. 
+ (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable) + assert child_tunable is composite_tunable # Check that the tunables are the same object. if child_group.name not in checked_child_env_groups: assert child_group is composite_group checked_child_env_groups.add(child_group.name) @@ -132,15 +106,10 @@ def test_load_composite_env_config_examples( assert child_tunable.value == old_cat_value assert child_group[child_tunable] == old_cat_value assert composite_env.tunable_params[child_tunable] == old_cat_value - new_cat_value = [ - x for x in child_tunable.categories if x != old_cat_value - ][0] + new_cat_value = [x for x in child_tunable.categories if x != old_cat_value][0] child_tunable.category = new_cat_value assert child_env.tunable_params[child_tunable] == new_cat_value - assert ( - composite_env.tunable_params[child_tunable] - == child_tunable.category - ) + assert composite_env.tunable_params[child_tunable] == child_tunable.category elif child_tunable.is_numerical: old_num_value = child_tunable.numerical_value assert child_tunable.value == old_num_value @@ -148,7 +117,4 @@ def test_load_composite_env_config_examples( assert composite_env.tunable_params[child_tunable] == old_num_value child_tunable.numerical_value += 1 assert child_env.tunable_params[child_tunable] == old_num_value + 1 - assert ( - composite_env.tunable_params[child_tunable] - == child_tunable.numerical_value - ) + assert composite_env.tunable_params[child_tunable] == child_tunable.numerical_value diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index fd53d63788..4d8c93fdff 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -29,9 +29,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ # *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), - *locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs - ), + *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, "experiments", filter_configs), ] @@ -39,9 +37,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) -def test_load_globals_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_globals_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.GLOBALS) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index c504a6d50f..6cb6253dea 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -30,16 +30,12 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples( - 
ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_optimizer_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_optimizer_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index 54d619caf1..e4264003e1 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -34,19 +34,14 @@ def __hash__(self) -> int: # The different type of schema test cases we expect to have. -_SCHEMA_TEST_TYPES = { - x.test_case_type: x - for x in ( - SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}), - SchemaTestType( - test_case_type="bad", test_case_subtypes={"invalid", "unhandled"} - ), - ) -} +_SCHEMA_TEST_TYPES = {x.test_case_type: x for x in ( + SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}), + SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}), +)} @dataclass -class SchemaTestCaseInfo: +class SchemaTestCaseInfo(): """ Some basic info about a schema test case. """ @@ -66,22 +61,15 @@ def check_schema_dir_layout(test_cases_root: str) -> None: any extra configs or test cases. """ for test_case_dir in os.listdir(test_cases_root): - if test_case_dir == "README.md": + if test_case_dir == 'README.md': continue if test_case_dir not in _SCHEMA_TEST_TYPES: raise NotImplementedError(f"Unhandled test case type: {test_case_dir}") - for test_case_subdir in os.listdir( - os.path.join(test_cases_root, test_case_dir) - ): - if test_case_subdir == "README.md": + for test_case_subdir in os.listdir(os.path.join(test_cases_root, test_case_dir)): + if test_case_subdir == 'README.md': continue - if ( - test_case_subdir - not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes - ): - raise NotImplementedError( - f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}" - ) + if test_case_subdir not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes: + raise NotImplementedError(f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}") @dataclass @@ -99,23 +87,15 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: """ Gets a dict of schema test cases from the given root. """ - test_cases = TestCases( - by_path={}, - by_type={x: {} for x in _SCHEMA_TEST_TYPES}, - by_subtype={ - y: {} - for x in _SCHEMA_TEST_TYPES - for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes - }, - ) + test_cases = TestCases(by_path={}, + by_type={x: {} for x in _SCHEMA_TEST_TYPES}, + by_subtype={y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes}) check_schema_dir_layout(test_cases_root) # Note: we sort the test cases so that we can deterministically test them in parallel. 
- for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items(): + for (test_case_type, schema_test_type) in _SCHEMA_TEST_TYPES.items(): for test_case_subtype in schema_test_type.test_case_subtypes: - for test_case_file in locate_config_examples( - test_cases_root, os.path.join(test_case_type, test_case_subtype) - ): - with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh: + for test_case_file in locate_config_examples(test_cases_root, os.path.join(test_case_type, test_case_subtype)): + with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh: try: test_case_info = SchemaTestCaseInfo( config=json5.load(test_case_fh), @@ -123,19 +103,11 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: test_case_type=test_case_type, test_case_subtype=test_case_subtype, ) - test_cases.by_path[test_case_info.test_case_file] = ( - test_case_info - ) - test_cases.by_type[test_case_info.test_case_type][ - test_case_info.test_case_file - ] = test_case_info - test_cases.by_subtype[test_case_info.test_case_subtype][ - test_case_info.test_case_file - ] = test_case_info + test_cases.by_path[test_case_info.test_case_file] = test_case_info + test_cases.by_type[test_case_info.test_case_type][test_case_info.test_case_file] = test_case_info + test_cases.by_subtype[test_case_info.test_case_subtype][test_case_info.test_case_file] = test_case_info except Exception as ex: - raise RuntimeError( - "Failed to load test case: " + test_case_file - ) from ex + raise RuntimeError("Failed to load test case: " + test_case_file) from ex assert test_cases assert len(test_cases.by_type["good"]) > 0 @@ -145,9 +117,7 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: return test_cases -def check_test_case_against_schema( - test_case: SchemaTestCaseInfo, schema_type: ConfigSchema -) -> None: +def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: """ Checks the given test case against the given schema. @@ -172,9 +142,7 @@ def check_test_case_against_schema( raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}") -def check_test_case_config_with_extra_param( - test_case: SchemaTestCaseInfo, schema_type: ConfigSchema -) -> None: +def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: """ Checks that the config fails to validate if extra params are present in certain places. """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index a3401baf7f..5dd1666008 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -26,7 +26,6 @@ # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_cli_configs_against_schema(test_case_name: str) -> None: """ @@ -37,9 +36,7 @@ def test_cli_configs_against_schema(test_case_name: str) -> None: # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. 
- check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -47,13 +44,9 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the cli config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index efb9e8019d..dc3cd40425 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -33,29 +33,23 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_ENV_CLASSES = { - ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. + ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. } -expected_environment_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Environment, pkg_name="mlos_bench") - if subclass not in NON_CONFIG_ENV_CLASSES -] +expected_environment_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass + in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench') + if subclass not in NON_CONFIG_ENV_CLASSES] assert expected_environment_class_names COMPOSITE_ENV_CLASS_NAME = CompositeEnv.__module__ + "." + CompositeEnv.__name__ -expected_leaf_environment_class_names = [ - subclass_name - for subclass_name in expected_environment_class_names - if subclass_name != COMPOSITE_ENV_CLASS_NAME -] +expected_leaf_environment_class_names = [subclass_name for subclass_name in expected_environment_class_names + if subclass_name != COMPOSITE_ENV_CLASS_NAME] # Do the full cross product of all the test cases and all the Environment types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("env_class", expected_environment_class_names) -def test_case_coverage_mlos_bench_environment_type( - test_case_subtype: str, env_class: str -) -> None: +def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_class: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench Environment type. 
""" @@ -63,24 +57,18 @@ def test_case_coverage_mlos_bench_environment_type( if try_resolve_class_name(test_case.config.get("class")) == env_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}" - ) + f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_environment_configs_against_schema(test_case_name: str) -> None: """ Checks that the environment config validates against the schema. """ - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.ENVIRONMENT - ) - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.ENVIRONMENT) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -88,9 +76,5 @@ def test_environment_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the environment config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 7cf497695b..5045bf510b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -25,19 +25,14 @@ # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_globals_configs_against_schema(test_case_name: str) -> None: """ Checks that the CLI config validates against the schema. """ - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. 
- check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index 6a9d43864f..e9ee653644 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -33,12 +33,9 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_optimizer_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses( - Optimizer, pkg_name="mlos_bench" # type: ignore[type-abstract] - ) -] +expected_mlos_bench_optimizer_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Optimizer, # type: ignore[type-abstract] + pkg_name='mlos_bench')] assert expected_mlos_bench_optimizer_class_names # Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None). @@ -52,25 +49,16 @@ # Do the full cross product of all the test cases and all the optimizer types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize( - "mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names -) -def test_case_coverage_mlos_bench_optimizer_type( - test_case_subtype: str, mlos_bench_optimizer_type: str -) -> None: +@pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names) +def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_bench_optimizer_type: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench optimizer type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): - if ( - try_resolve_class_name(test_case.config.get("class")) - == mlos_bench_optimizer_type - ): + if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}" - ) - + f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}") # Being a little lazy for the moment and relaxing the requirement that we have # a subtype test case for each optimizer and space adapter combo. @@ -79,77 +67,54 @@ def test_case_coverage_mlos_bench_optimizer_type( @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types) -def test_case_coverage_mlos_core_optimizer_type( - test_case_type: str, mlos_core_optimizer_type: Optional[OptimizerType] -) -> None: +def test_case_coverage_mlos_core_optimizer_type(test_case_type: str, + mlos_core_optimizer_type: Optional[OptimizerType]) -> None: """ Checks to see if there is a given type of test case for the given mlos_core optimizer type. 
""" - optimizer_name = ( - None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name - ) + optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): - if ( - try_resolve_class_name(test_case.config.get("class")) - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" - ): + if try_resolve_class_name(test_case.config.get("class")) \ + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": optimizer_type = None if test_case.config.get("config"): optimizer_type = test_case.config["config"].get("optimizer_type", None) if optimizer_type == optimizer_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}" - ) + f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}") @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize( - "mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types -) -def test_case_coverage_mlos_core_space_adapter_type( - test_case_type: str, mlos_core_space_adapter_type: Optional[SpaceAdapterType] -) -> None: +@pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types) +def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str, + mlos_core_space_adapter_type: Optional[SpaceAdapterType]) -> None: """ Checks to see if there is a given type of test case for the given mlos_core space adapter type. """ - space_adapter_name = ( - None - if mlos_core_space_adapter_type is None - else mlos_core_space_adapter_type.name - ) + space_adapter_name = None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): - if ( - try_resolve_class_name(test_case.config.get("class")) - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" - ): + if try_resolve_class_name(test_case.config.get("class")) \ + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": space_adapter_type = None if test_case.config.get("config"): - space_adapter_type = test_case.config["config"].get( - "space_adapter_type", None - ) + space_adapter_type = test_case.config["config"].get("space_adapter_type", None) if space_adapter_type == space_adapter_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}" - ) + f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_optimizer_configs_against_schema(test_case_name: str) -> None: """ Checks that the optimizer config validates against the schema. 
""" - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.OPTIMIZER - ) - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.OPTIMIZER) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -157,9 +122,5 @@ def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the optimizer config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 279c171a90..8fccba8bc7 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -30,37 +30,25 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_scheduler_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses( - Scheduler, pkg_name="mlos_bench" # type: ignore[type-abstract] - ) -] +expected_mlos_bench_scheduler_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Scheduler, # type: ignore[type-abstract] + pkg_name='mlos_bench')] assert expected_mlos_bench_scheduler_class_names # Do the full cross product of all the test cases and all the scheduler types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize( - "mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names -) -def test_case_coverage_mlos_bench_scheduler_type( - test_case_subtype: str, mlos_bench_scheduler_type: str -) -> None: +@pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names) +def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_bench_scheduler_type: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench scheduler type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): - if ( - try_resolve_class_name(test_case.config.get("class")) - == mlos_bench_scheduler_type - ): + if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_scheduler_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}" - ) - + f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}") # Now we actually perform all of those validation tests. @@ -70,12 +58,8 @@ def test_scheduler_configs_against_schema(test_case_name: str) -> None: """ Checks that the scheduler config validates against the schema. 
""" - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.SCHEDULER - ) - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SCHEDULER) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -83,12 +67,8 @@ def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the scheduler config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index f7daf3f422..64c6fccccd 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -38,33 +38,30 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_SERVICE_CLASSES = { - ConfigPersistenceService, # configured thru the launcher cli args - TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. - AzureDeploymentService, # ABCMeta abstract base class - SshService, # ABCMeta abstract base class + ConfigPersistenceService, # configured thru the launcher cli args + TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. + AzureDeploymentService, # ABCMeta abstract base class + SshService, # ABCMeta abstract base class } -expected_service_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Service, pkg_name="mlos_bench") - if subclass not in NON_CONFIG_SERVICE_CLASSES -] +expected_service_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass + in get_all_concrete_subclasses(Service, pkg_name='mlos_bench') + if subclass not in NON_CONFIG_SERVICE_CLASSES] assert expected_service_class_names # Do the full cross product of all the test cases and all the Service types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("service_class", expected_service_class_names) -def test_case_coverage_mlos_bench_service_type( - test_case_subtype: str, service_class: str -) -> None: +def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_class: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench Service type. 
""" for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): config_list: List[Dict[str, Any]] if not isinstance(test_case.config, dict): - continue # type: ignore[unreachable] + continue # type: ignore[unreachable] if "class" not in test_case.config: config_list = test_case.config["services"] else: @@ -73,24 +70,18 @@ def test_case_coverage_mlos_bench_service_type( if try_resolve_class_name(config.get("class")) == service_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for service class {service_class}" - ) + f"Missing test case for subtype {test_case_subtype} for service class {service_class}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_service_configs_against_schema(test_case_name: str) -> None: """ Checks that the service config validates against the schema. """ - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.SERVICE - ) - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SERVICE) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -98,9 +89,5 @@ def test_service_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the service config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index 640ae450f3..9b362b5e0d 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -28,52 +28,36 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_storage_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses( - Storage, pkg_name="mlos_bench" # type: ignore[type-abstract] - ) -] +expected_mlos_bench_storage_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Storage, # type: ignore[type-abstract] + pkg_name='mlos_bench')] assert expected_mlos_bench_storage_class_names # Do the full cross product of all the test cases and all the storage types. 
@pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) -@pytest.mark.parametrize( - "mlos_bench_storage_type", expected_mlos_bench_storage_class_names -) -def test_case_coverage_mlos_bench_storage_type( - test_case_subtype: str, mlos_bench_storage_type: str -) -> None: +@pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names) +def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_bench_storage_type: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench storage type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): - if ( - try_resolve_class_name(test_case.config.get("class")) - == mlos_bench_storage_type - ): + if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_storage_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}" - ) + f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_storage_configs_against_schema(test_case_name: str) -> None: """ Checks that the storage config validates against the schema. """ - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.STORAGE - ) - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.STORAGE) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) @@ -81,15 +65,9 @@ def test_storage_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the storage config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) - - -if __name__ == "__main__": - pytest.main( - [__file__, "-n0"], - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + + +if __name__ == '__main__': + pytest.main([__file__, '-n0'],) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py index cf0223d006..a6d0de9313 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py @@ -25,15 +25,10 @@ # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_params_configs_against_schema(test_case_name: str) -> None: """ Checks that the tunable params config validates against the schema. 
""" - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_PARAMS - ) - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_PARAMS) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index 04d2f4c709..d871eaa212 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -25,19 +25,14 @@ # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_values_configs_against_schema(test_case_name: str) -> None: """ Checks that the tunable values config validates against the schema. """ - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, # so adding/removing params doesn't invalidate it against all of the config types. - check_test_case_against_schema( - TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 8431251098..32034eb11c 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -25,27 +25,19 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" - def predicate(config_path: str) -> bool: - arm_template = config_path.find( - "services/remote/azure/arm-templates/" - ) >= 0 and config_path.endswith(".jsonc") + arm_template = config_path.find("services/remote/azure/arm-templates/") >= 0 and config_path.endswith(".jsonc") setup_rg_scripts = config_path.find("azure/scripts/setup-rg") >= 0 return not (arm_template or setup_rg_scripts) - return [config_path for config_path in configs_to_filter if predicate(config_path)] -configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_service_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_service_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) # Make an instance of the class based on the config. 
diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index d1d39ec4f5..2f9773a9b0 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -29,16 +29,12 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_storage_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_storage_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.STORAGE) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index 28c83f453c..58359eb983 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -42,7 +42,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score"], }, - tunables=tunable_groups, + tunables=tunable_groups ) @@ -59,7 +59,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score", "other_score"], }, - tunables=tunable_groups, + tunables=tunable_groups ) @@ -82,9 +82,7 @@ def docker_compose_file(pytestconfig: pytest.Config) -> List[str]: """ _ = pytestconfig # unused return [ - os.path.join( - os.path.dirname(__file__), "services", "remote", "ssh", "docker-compose.yml" - ), + os.path.join(os.path.dirname(__file__), "services", "remote", "ssh", "docker-compose.yml"), # Add additional configs as necessary here. ] @@ -105,9 +103,7 @@ def docker_compose_project_name(short_testrun_uid: str) -> str: @pytest.fixture(scope="session") -def docker_services_lock( - shared_temp_dir: str, short_testrun_uid: str -) -> InterProcessReaderWriterLock: +def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessReaderWriterLock: """ Gets a pytest session lock for xdist workers to mark when they're using the docker services. @@ -117,15 +113,11 @@ def docker_services_lock( A lock to ensure that setup/teardown operations don't happen while a worker is using the docker services. """ - return InterProcessReaderWriterLock( - f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock" - ) + return InterProcessReaderWriterLock(f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock") @pytest.fixture(scope="session") -def docker_setup_teardown_lock( - shared_temp_dir: str, short_testrun_uid: str -) -> InterProcessLock: +def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessLock: """ Gets a pytest session lock between xdist workers for the docker setup/teardown operations. @@ -134,9 +126,7 @@ def docker_setup_teardown_lock( ------ A lock to ensure that only one worker is doing setup/teardown at a time. 
""" - return InterProcessLock( - f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock" - ) + return InterProcessLock(f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock") @pytest.fixture(scope="session") diff --git a/mlos_bench/mlos_bench/tests/dict_templater_test.py b/mlos_bench/mlos_bench/tests/dict_templater_test.py index 6604656c9a..63219d9246 100644 --- a/mlos_bench/mlos_bench/tests/dict_templater_test.py +++ b/mlos_bench/mlos_bench/tests/dict_templater_test.py @@ -124,9 +124,7 @@ def test_from_extras_expansion(source_template_dict: Dict[str, Any]) -> None: "extra_str": "str-from-extras", "string": "shouldn't be used", } - results = DictTemplater(source_template_dict).expand_vars( - extra_source_dict=extra_source_dict - ) + results = DictTemplater(source_template_dict).expand_vars(extra_source_dict=extra_source_dict) assert results == { "extra_str-ref": f"{extra_source_dict['extra_str']}-ref", "str": "string", diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py index 8218577986..ac0b942167 100644 --- a/mlos_bench/mlos_bench/tests/environments/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/__init__.py @@ -16,13 +16,11 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def check_env_success( - env: Environment, - tunable_groups: TunableGroups, - expected_results: Dict[str, TunableValue], - expected_telemetry: List[Tuple[datetime, str, Any]], - global_config: Optional[dict] = None, -) -> None: +def check_env_success(env: Environment, + tunable_groups: TunableGroups, + expected_results: Dict[str, TunableValue], + expected_telemetry: List[Tuple[datetime, str, Any]], + global_config: Optional[dict] = None) -> None: """ Set up an environment and run a test experiment there. @@ -52,7 +50,7 @@ def check_env_success( assert telemetry == pytest.approx(expected_telemetry, nan_ok=True) env_context.teardown() - assert not env_context._is_ready # pylint: disable=protected-access + assert not env_context._is_ready # pylint: disable=protected-access def check_env_fail_telemetry(env: Environment, tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index 863e5aaa80..8afb8e5cda 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -29,8 +29,8 @@ def test_expand_groups() -> None: Check the dollar variable expansion for tunable groups. 
""" assert Environment._expand_groups( - ["begin", "$list", "$empty", "$str", "end"], _GROUPS - ) == ["begin", "c", "d", "efg", "end"] + ["begin", "$list", "$empty", "$str", "end"], + _GROUPS) == ["begin", "c", "d", "efg", "end"] def test_expand_groups_empty_input() -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index f8f6d28afe..6497eb6985 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -34,32 +34,26 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: { "name": "Env 2 :: tmp_other_2", "class": "mlos_bench.environments.mock_env.MockEnv", - "include_services": [ - "services/local/mock/mock_local_exec_service_2.jsonc" - ], + "include_services": ["services/local/mock/mock_local_exec_service_2.jsonc"], }, { "name": "Env 3 :: tmp_other_3", "class": "mlos_bench.environments.mock_env.MockEnv", - "include_services": [ - "services/local/mock/mock_local_exec_service_3.jsonc" - ], - }, + "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"], + } ] }, tunables=tunable_groups, service=LocalExecService( - config={"temp_dir": "_test_tmp_global"}, - parent=ConfigPersistenceService( - { - "config_path": [ - path_join( - os.path.dirname(__file__), "../config", abs_path=True - ), - ] - } - ), - ), + config={ + "temp_dir": "_test_tmp_global" + }, + parent=ConfigPersistenceService({ + "config_path": [ + path_join(os.path.dirname(__file__), "../config", abs_path=True), + ] + }) + ) ) @@ -67,11 +61,7 @@ def test_composite_services(composite_env: CompositeEnv) -> None: """ Check that each environment gets its own instance of the services. 
""" - for i, path in ( - (0, "_test_tmp_global"), - (1, "_test_tmp_other_2"), - (2, "_test_tmp_other_3"), - ): + for (i, path) in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): service = composite_env.children[i]._service # pylint: disable=protected-access assert service is not None and hasattr(service, "temp_dir_context") with service.temp_dir_context() as temp_dir: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index 184aad778d..742eaf3c79 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -28,7 +28,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", "someConst": "root", - "global_param": "default", + "global_param": "default" }, "children": [ { @@ -43,7 +43,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "someConst", "global_param"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, { "name": "Mock Server Environment 2", @@ -53,12 +53,12 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vmName": "$vm_server_name", "EnvId": 2, - "global_param": "local", + "global_param": "local" }, "required_args": ["vmName"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, { "name": "Mock Control Environment 3", @@ -72,13 +72,15 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "vm_server_name", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, - }, - ], + } + } + ] }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={"global_param": "global_value"}, + global_config={ + "global_param": "global_value" + } ) @@ -88,65 +90,61 @@ def test_composite_env_params(composite_env: CompositeEnv) -> None: NOTE: The current logic is that variables flow down via required_args and const_args, parent """ assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value", # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args 
from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", - "vm_server_name": "Mock Server VM", + "vm_server_name": "Mock Server VM" # "global_param": "global_value" # not required, so not picked from the global_config } -def test_composite_env_setup( - composite_env: CompositeEnv, tunable_groups: TunableGroups -) -> None: +def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: """ Check that the child environments update their tunable parameters. """ - tunable_groups.assign( - { - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - } - ) + tunable_groups.assign({ + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + }) with composite_env as env_context: assert env_context.setup(tunable_groups) assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value", # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "mwait", # tunable_params from the parent "vm_client_name": "Mock Client VM", "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config @@ -165,7 +163,7 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", - "someConst": "root", + "someConst": "root" }, "children": [ { @@ -177,12 +175,7 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "vmName": "$vm_client_name", "EnvId": 1, }, - "required_args": [ - "vmName", - "EnvId", - "someConst", - "vm_server_name", - ], + "required_args": ["vmName", "EnvId", "someConst", "vm_server_name"], "children": [ { "name": "Mock Client Environment 1", @@ -198,11 +191,11 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "EnvId", "someConst", "vm_server_name", - "global_param", + 
"global_param" ], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, # ... ], @@ -224,24 +217,23 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "class": "mlos_bench.environments.mock_env.MockEnv", "config": { "tunable_params": ["boot"], - "required_args": [ - "vmName", - "EnvId", - "vm_client_name", - ], + "required_args": ["vmName", "EnvId", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, # ... ], }, }, - ], + + ] }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={"global_param": "global_value"}, + global_config={ + "global_param": "global_value" + } ) @@ -252,56 +244,52 @@ def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None """ assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", # "global_param": "global_value" # not required, so not picked from the global_config } -def test_nested_composite_env_setup( - nested_composite_env: CompositeEnv, tunable_groups: TunableGroups -) -> None: +def test_nested_composite_env_setup(nested_composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: """ Check that the child environments update their tunable parameters. 
""" - tunable_groups.assign( - { - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - } - ) + tunable_groups.assign({ + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + }) with nested_composite_env as env_context: assert env_context.setup(tunable_groups) assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", } diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py index bf3407b506..7395aa3e15 100644 --- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py +++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py @@ -18,7 +18,7 @@ def test_one_group(tunable_groups: TunableGroups) -> None: env = MockEnv( name="Test Env", config={"tunable_params": ["provision"]}, - tunables=tunable_groups, + tunables=tunable_groups ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -32,7 +32,7 @@ def test_two_groups(tunable_groups: TunableGroups) -> None: env = MockEnv( name="Test Env", config={"tunable_params": ["provision", "kernel"]}, - tunables=tunable_groups, + tunables=tunable_groups ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -55,7 +55,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups, + tunables=tunable_groups ) expected_params = { "vmSize": "Standard_B4ms", @@ -80,7 +80,11 @@ def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None: """ Make sure that no tunable groups are available to the environment by default. """ - env = MockEnv(name="Test Env", config={}, tunables=tunable_groups) + env = MockEnv( + name="Test Env", + config={}, + tunables=tunable_groups + ) assert env.tunable_params.get_param_values() == {} @@ -90,7 +94,9 @@ def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None: when explicitly specifying an empty list of tunable_params. 
""" env = MockEnv( - name="Test Env", config={"tunable_params": []}, tunables=tunable_groups + name="Test Env", + config={"tunable_params": []}, + tunables=tunable_groups ) assert env.tunable_params.get_param_values() == {} @@ -108,7 +114,7 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups, + tunables=tunable_groups ) assert env.tunable_params.get_param_values() == {} @@ -131,7 +137,9 @@ def test_loader_level_include() -> None: env_json = { "class": "mlos_bench.environments.mock_env.MockEnv", "name": "Test Env", - "include_tunables": ["environments/os/linux/boot/linux-boot-tunables.jsonc"], + "include_tunables": [ + "environments/os/linux/boot/linux-boot-tunables.jsonc" + ], "config": { "tunable_params": ["linux-kernel-boot"], "const_args": { @@ -140,14 +148,12 @@ def test_loader_level_include() -> None: }, }, } - loader = ConfigPersistenceService( - { - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - } - ) + loader = ConfigPersistenceService({ + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + }) env = loader.build_environment(config=env_json, tunables=TunableGroups()) expected_params = { "align_va_addr": "on", diff --git a/mlos_bench/mlos_bench/tests/environments/local/__init__.py b/mlos_bench/mlos_bench/tests/environments/local/__init__.py index c68d2fa7b8..5d8fc32c6b 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/local/__init__.py @@ -32,20 +32,14 @@ def create_local_env(tunable_groups: TunableGroups, config: Dict[str, Any]) -> L env : LocalEnv A new instance of the local environment. """ - return LocalEnv( - name="TestLocalEnv", - config=config, - tunables=tunable_groups, - service=LocalExecService(parent=ConfigPersistenceService()), - ) + return LocalEnv(name="TestLocalEnv", config=config, tunables=tunable_groups, + service=LocalExecService(parent=ConfigPersistenceService())) -def create_composite_local_env( - tunable_groups: TunableGroups, - global_config: Dict[str, Any], - params: Dict[str, Any], - local_configs: List[Dict[str, Any]], -) -> CompositeEnv: +def create_composite_local_env(tunable_groups: TunableGroups, + global_config: Dict[str, Any], + params: Dict[str, Any], + local_configs: List[Dict[str, Any]]) -> CompositeEnv: """ Create a CompositeEnv with several LocalEnv instances. @@ -76,7 +70,7 @@ def create_composite_local_env( "config": config, } for (i, config) in enumerate(local_configs) - ], + ] }, tunables=tunable_groups, global_config=global_config, diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index c38c6bc584..9bcb7aa218 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -26,9 +26,7 @@ def _format_str(zone_info: Optional[tzinfo]) -> str: # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_composite_env( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: """ Produce benchmark and telemetry data in TWO local environments and combine the results. 
@@ -45,7 +43,7 @@ def test_composite_env( time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - (var_prefix, var_suffix) = ("%", "%") if sys.platform == "win32" else ("$", "") + (var_prefix, var_suffix) = ("%", "%") if sys.platform == 'win32' else ("$", "") env = create_composite_local_env( tunable_groups=tunable_groups, @@ -69,8 +67,8 @@ def test_composite_env( "required_args": ["errors", "reads"], "shell_env_params": [ "latency", # const_args overridden by the composite env - "errors", # Comes from the parent const_args - "reads", # const_args overridden by the global config + "errors", # Comes from the parent const_args + "reads" # const_args overridden by the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -92,9 +90,9 @@ def test_composite_env( }, "required_args": ["writes"], "shell_env_params": [ - "throughput", # const_args overridden by the composite env - "score", # Comes from the local const_args - "writes", # Comes straight from the global config + "throughput", # const_args overridden by the composite env + "score", # Comes from the local const_args + "writes" # Comes straight from the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -108,13 +106,12 @@ def test_composite_env( ], "read_results_file": "output.csv", "read_telemetry_file": "telemetry.csv", - }, - ], + } + ] ) check_env_success( - env, - tunable_groups, + env, tunable_groups, expected_results={ "latency": 4.2, "throughput": 768.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py index bdcd9f885f..20854b9f9e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py @@ -17,23 +17,19 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", - ], - "results_stdout_pattern": r"(\w+),([0-9.]+)", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", + ], + "results_stdout_pattern": r"(\w+),([0-9.]+)", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -49,23 +45,19 @@ def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. 
""" - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern - ], - "results_stdout_pattern": r"^(\w+),([0-9.]+)$", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern + ], + "results_stdout_pattern": r"^(\w+),([0-9.]+)$", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -80,28 +72,24 @@ def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to *BOTH* stdout and a file and extract the results from both. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'stdout-msg,string'", - "echo '-------------------'", # Should be ignored - "echo 'metric,value' > output.csv", - "echo 'extra1,333' >> output.csv", - "echo 'extra2,444' >> output.csv", - "echo 'file-msg,string' >> output.csv", - ], - "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'stdout-msg,string'", + "echo '-------------------'", # Should be ignored + "echo 'metric,value' > output.csv", + "echo 'extra1,333' >> output.csv", + "echo 'extra2,444' >> output.csv", + "echo 'file-msg,string' >> output.csv", + ], + "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", + "read_results_file": "output.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index f620165de8..35bdb39486 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -25,9 +25,7 @@ def _format_str(zone_info: Optional[tzinfo]) -> str: # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: """ Produce benchmark and telemetry data in a local script and read it. 
""" @@ -39,29 +37,25 @@ def test_local_env_telemetry( time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,4.1' >> output.csv", - "echo 'throughput,512' >> output.csv", - "echo 'score,0.95' >> output.csv", - "echo '-------------------'", # This output does not go anywhere - "echo 'timestamp,metric,value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_results_file": "output.csv", - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,4.1' >> output.csv", + "echo 'throughput,512' >> output.csv", + "echo 'score,0.95' >> output.csv", + "echo '-------------------'", # This output does not go anywhere + "echo 'timestamp,metric,value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_results_file": "output.csv", + "read_telemetry_file": "telemetry.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 4.1, "throughput": 512.0, @@ -78,9 +72,7 @@ def test_local_env_telemetry( # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_no_header( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: """ Read the telemetry data with no header. 
""" @@ -92,22 +84,18 @@ def test_local_env_telemetry_no_header( time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - f"echo {time_str1},cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + f"echo {time_str1},cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={}, expected_telemetry=[ (ts1.astimezone(UTC), "cpu_load", 0.65), @@ -118,13 +106,9 @@ def test_local_env_telemetry_no_header( ) -@pytest.mark.filterwarnings( - "ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0" -) # pylint: disable=line-too-long # noqa +@pytest.mark.filterwarnings("ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0") # pylint: disable=line-too-long # noqa @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_wrong_header( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: """ Read the telemetry data with incorrect header. """ @@ -136,20 +120,17 @@ def test_local_env_telemetry_wrong_header( time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - # Error: the data is correct, but the header has unexpected column names - "echo 'ts,metric_name,metric_value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # Error: the data is correct, but the header has unexpected column names + "echo 'ts,metric_name,metric_value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_fail_telemetry(local_env, tunable_groups) @@ -167,19 +148,16 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - # Error: too many columns - f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # Error: too many columns + f"echo 
{time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_fail_telemetry(local_env, tunable_groups) @@ -188,18 +166,15 @@ def test_local_env_telemetry_invalid_ts(tunable_groups: TunableGroups) -> None: """ Fail when the telemetry data has wrong format. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - # Error: field 1 must be a timestamp - "echo 1,cpu_load,0.65 > telemetry.csv", - "echo 2,mem_usage,10240 >> telemetry.csv", - "echo 3,cpu_load,0.8 >> telemetry.csv", - "echo 4,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # Error: field 1 must be a timestamp + "echo 1,cpu_load,0.65 > telemetry.csv", + "echo 2,mem_usage,10240 >> telemetry.csv", + "echo 3,cpu_load,0.8 >> telemetry.csv", + "echo 4,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_fail_telemetry(local_env, tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 2b51ae1f0e..6cb4fd4f7e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -16,22 +16,18 @@ def test_local_env(tunable_groups: TunableGroups) -> None: """ Produce benchmark and telemetry data in a local script and read it. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,10' >> output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,10' >> output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 10.0, "throughput": 66.0, @@ -45,7 +41,9 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: """ Basic check that context support for Service mixins are handled when environment contexts are entered. 
""" - local_env = create_local_env(tunable_groups, {"run": ["echo NA"]}) + local_env = create_local_env(tunable_groups, { + "run": ["echo NA"] + }) # pylint: disable=protected-access assert local_env._service assert not local_env._service._in_context @@ -53,10 +51,10 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: with local_env as env_context: assert env_context._in_context assert local_env._service._in_context - assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) + assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) assert all(svc._in_context for svc in local_env._service._service_contexts) assert all(svc._in_context for svc in local_env._service._services) - assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) + assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) assert not local_env._service._service_contexts assert not any(svc._in_context for svc in local_env._service._services) @@ -65,18 +63,15 @@ def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: """ Fail if the results are not in the expected format. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - # No header - "echo 'latency,10' > output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # No header + "echo 'latency,10' > output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }) with local_env as env_context: assert env_context.setup(tunable_groups) @@ -88,20 +83,16 @@ def test_local_env_wide(tunable_groups: TunableGroups) -> None: """ Produce benchmark data in wide format and read it. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'latency,throughput,score' > output.csv", - "echo '10,66,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'latency,throughput,score' > output.csv", + "echo '10,66,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 10, "throughput": 66, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py index 52e15be076..c16eac4459 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py @@ -14,36 +14,31 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def _run_local_env( - tunable_groups: TunableGroups, shell_subcmd: str, expected: dict -) -> None: +def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: dict) -> None: """ Check that LocalEnv can set shell environment variables. 
""" - local_env = create_local_env( - tunable_groups, - { - "const_args": { - "const_arg": 111, # Passed into "shell_env_params" - "other_arg": 222, # NOT passed into "shell_env_params" - }, - "tunable_params": ["kernel"], - "shell_env_params": [ - "const_arg", # From "const_arg" - "kernel_sched_latency_ns", # From "tunable_params" - ], - "run": [ - "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", - f"echo {shell_subcmd} >> output.csv", - ], - "read_results_file": "output.csv", + local_env = create_local_env(tunable_groups, { + "const_args": { + "const_arg": 111, # Passed into "shell_env_params" + "other_arg": 222, # NOT passed into "shell_env_params" }, - ) + "tunable_params": ["kernel"], + "shell_env_params": [ + "const_arg", # From "const_arg" + "kernel_sched_latency_ns", # From "tunable_params" + ], + "run": [ + "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", + f"echo {shell_subcmd} >> output.csv", + ], + "read_results_file": "output.csv", + }) check_env_success(local_env, tunable_groups, expected, []) -@pytest.mark.skipif(sys.platform == "win32", reason="sh-like shell only") +@pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only") def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables in sh-like shell. @@ -52,15 +47,15 @@ def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd="$const_arg,$other_arg,$unknown_arg,$kernel_sched_latency_ns", expected={ - "const_arg": 111, # From "const_args" - "other_arg": float("NaN"), # Not included in "shell_env_params" - "unknown_arg": float("NaN"), # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - }, + "const_arg": 111, # From "const_args" + "other_arg": float("NaN"), # Not included in "shell_env_params" + "unknown_arg": float("NaN"), # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + } ) -@pytest.mark.skipif(sys.platform != "win32", reason="Windows only") +@pytest.mark.skipif(sys.platform != 'win32', reason="Windows only") def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables on Windows / cmd shell. 
@@ -69,9 +64,9 @@ def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd=r"%const_arg%,%other_arg%,%unknown_arg%,%kernel_sched_latency_ns%", expected={ - "const_arg": 111, # From "const_args" - "other_arg": r"%other_arg%", # Not included in "shell_env_params" - "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - }, + "const_arg": 111, # From "const_args" + "other_arg": r"%other_arg%", # Not included in "shell_env_params" + "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + } ) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index 25e75cf748..8bce053f7b 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -25,14 +25,13 @@ def mock_fileshare_service() -> MockFileShareService: """ return MockFileShareService( config={"fileShareName": "MOCK_FILESHARE"}, - parent=LocalExecService(parent=ConfigPersistenceService()), + parent=LocalExecService(parent=ConfigPersistenceService()) ) @pytest.fixture -def local_fileshare_env( - tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService -) -> LocalFileShareEnv: +def local_fileshare_env(tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService) -> LocalFileShareEnv: """ Create a LocalFileShareEnv instance. """ @@ -41,12 +40,12 @@ def local_fileshare_env( config={ "const_args": { "experiment_id": "EXP_ID", # Passed into "shell_env_params" - "trial_id": 222, # NOT passed into "shell_env_params" + "trial_id": 222, # NOT passed into "shell_env_params" }, "tunable_params": ["boot"], "shell_env_params": [ - "trial_id", # From "const_arg" - "idle", # From "tunable_params", == "halt" + "trial_id", # From "const_arg" + "idle", # From "tunable_params", == "halt" ], "upload": [ { @@ -58,7 +57,9 @@ def local_fileshare_env( "to": "$experiment_id/$trial_id/input/data_$idle.csv", }, ], - "run": ["echo No-op run"], + "run": [ + "echo No-op run" + ], "download": [ { "from": "$experiment_id/$trial_id/$idle/data.csv", @@ -72,11 +73,9 @@ def local_fileshare_env( return env -def test_local_fileshare_env( - tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService, - local_fileshare_env: LocalFileShareEnv, -) -> None: +def test_local_fileshare_env(tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService, + local_fileshare_env: LocalFileShareEnv) -> None: """ Test that the LocalFileShareEnv correctly expands the `$VAR` variables in the upload and download sections of the config. diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py index 427fe90706..608edbf9ef 100644 --- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py @@ -28,9 +28,7 @@ def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> N assert data["score"] == pytest.approx(72.92, 0.01) -def test_mock_env_no_noise( - mock_env_no_noise: MockEnv, tunable_groups: TunableGroups -) -> None: +def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups) -> None: """ Check the default values of the mock environment. 
""" @@ -44,33 +42,20 @@ def test_mock_env_no_noise( assert data["score"] == pytest.approx(75.0, 0.01) -@pytest.mark.parametrize( - ("tunable_values", "expected_score"), - [ - ( - { - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000, - }, - 66.4, - ), - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000, - }, - 74.06, - ), - ], -) -def test_mock_env_assign( - mock_env: MockEnv, - tunable_groups: TunableGroups, - tunable_values: dict, - expected_score: float, -) -> None: +@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ + ({ + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 250000 + }, 66.4), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000 + }, 74.06), +]) +def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, + tunable_values: dict, expected_score: float) -> None: """ Check the benchmark values of the mock environment after the assignment. """ @@ -83,33 +68,21 @@ def test_mock_env_assign( assert data["score"] == pytest.approx(expected_score, 0.01) -@pytest.mark.parametrize( - ("tunable_values", "expected_score"), - [ - ( - { - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000, - }, - 67.5, - ), - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000, - }, - 75.1, - ), - ], -) -def test_mock_env_no_noise_assign( - mock_env_no_noise: MockEnv, - tunable_groups: TunableGroups, - tunable_values: dict, - expected_score: float, -) -> None: +@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ + ({ + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 250000 + }, 67.5), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000 + }, 75.1), +]) +def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, + tunable_groups: TunableGroups, + tunable_values: dict, expected_score: float) -> None: """ Check the benchmark values of the noiseless mock environment after the assignment. 
""" diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index 6d47d1fc61..878531d799 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -38,31 +38,25 @@ def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None: "ssh_priv_key_path": ssh_test_server.id_rsa_path, } - service = ConfigPersistenceService( - config={"config_path": [str(files("mlos_bench.tests.config"))]} - ) + service = ConfigPersistenceService(config={"config_path": [str(files("mlos_bench.tests.config"))]}) config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc") - env = service.load_environment( - config_path, TunableGroups(), global_config=global_config, service=service - ) + env = service.load_environment(config_path, TunableGroups(), global_config=global_config, service=service) check_env_success( - env, - env.tunable_params, + env, env.tunable_params, expected_results={ "hostname": ssh_test_server.service_name, "username": ssh_test_server.username, "score": 0.9, - "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" + "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" "test_param": "unset", "FOO": "unset", "ssh_username": "unset", }, expected_telemetry=[], ) - assert not os.path.exists( - os.path.join(os.getcwd(), "output-downloaded.csv") - ), "output-downloaded.csv should have been cleaned up by temp_dir context" + assert not os.path.exists(os.path.join(os.getcwd(), "output-downloaded.csv")), \ + "output-downloaded.csv should have been cleaned up by temp_dir context" if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index fc00e5cb65..377bc940a0 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -40,21 +40,16 @@ def __enter__(self) -> None: self.EVENT_LOOP_CONTEXT.enter() self._in_context = True - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: assert self._in_context self.EVENT_LOOP_CONTEXT.exit() self._in_context = False return False -@pytest.mark.filterwarnings( - "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" -) +@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") def test_event_loop_context() -> None: """Test event loop context background thread setup/cleanup handling.""" # pylint: disable=protected-access,too-many-statements @@ -90,20 +85,14 @@ def test_event_loop_context() -> None: with event_loop_caller_instance_2: assert event_loop_caller_instance_2._in_context assert event_loop_caller_instance_1._in_context - assert ( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 - ) + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 # We should only get one thread for all instances. 
- assert ( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread \ + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread \ is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop_thread - ) - assert ( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop \ + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop \ is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop - ) assert not event_loop_caller_instance_2._in_context @@ -115,43 +104,31 @@ def test_event_loop_context() -> None: assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( - asyncio.sleep(0.1, result="foo") - ) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == "foo" + assert future.result(timeout=0.2) == 'foo' assert 0.1 <= time.time() - start <= 0.2 # Once we exit the last context, the background thread should be stopped # and unusable for running co-routines. - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 0 - assert ( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None - ) + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None assert not EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() # Check that the event loop has no more tasks. - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_ready") + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_ready') # Windows ProactorEventLoopPolicy adds a dummy task. 
- if sys.platform == "win32" and isinstance( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop - ): + if sys.platform == 'win32' and isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop): assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 1 else: assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 0 - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_scheduled") + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_scheduled') assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) == 0 - with pytest.raises( - AssertionError - ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( - asyncio.sleep(0.1, result="foo") - ) - raise ValueError( - f"Future should not have been available to wait on {future.result()}" - ) + with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + raise ValueError(f"Future should not have been available to wait on {future.result()}") # Test that when re-entering the context we have the same event loop. with event_loop_caller_instance_1: @@ -161,14 +138,12 @@ def test_event_loop_context() -> None: # Test running again. start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( - asyncio.sleep(0.1, result="foo") - ) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == "foo" + assert future.result(timeout=0.2) == 'foo' assert 0.1 <= time.time() - start <= 0.2 -if __name__ == "__main__": +if __name__ == '__main__': # For debugging in Windows which has issues with pytest detection in vscode. pytest.main(["-n1", "--dist=no", "-k", "test_event_loop_context"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py index 25abf659ce..90aa7e08f7 100644 --- a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py @@ -14,33 +14,19 @@ @pytest.mark.parametrize( - ("argv", "expected_score"), - [ - ( - [ - "--config", - "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", - "--trial_config_repeat_count", - "5", - "--mock_env_seed", - "-1", # Deterministic Mock Environment. - ], - 67.40329, - ), - ( - [ - "--config", - "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", - "--trial_config_repeat_count", - "3", - "--max_suggestions", - "3", - "--mock_env_seed", - "42", # Noisy Mock Environment. - ], - 64.53897, - ), - ], + ("argv", "expected_score"), [ + ([ + "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", + "--trial_config_repeat_count", "5", + "--mock_env_seed", "-1", # Deterministic Mock Environment. + ], 67.40329), + ([ + "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", + "--trial_config_repeat_count", "3", + "--max_suggestions", "3", + "--mock_env_seed", "42", # Noisy Mock Environment. 
+ ], 64.53897), + ] ) def test_main_bench(argv: List[str], expected_score: float) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index 39a9ae1a9b..634050d099 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -48,8 +48,8 @@ def config_paths() -> List[str]: """ return [ path_join(os.getcwd(), abs_path=True), - str(files("mlos_bench.config")), - str(files("mlos_bench.tests.config")), + str(files('mlos_bench.config')), + str(files('mlos_bench.tests.config')), ] @@ -64,23 +64,20 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == "win32": + if sys.platform == 'win32': # Some env tweaks for platform compatibility. - environ["USER"] = environ["USERNAME"] + environ['USER'] = environ['USERNAME'] # This is part of the minimal required args by the Launcher. - env_conf_path = "environments/mock/mock_env.jsonc" - cli_args = ( - "--config-paths " - + " ".join(config_paths) - + " --service services/remote/mock/mock_auth_service.jsonc" - + " --service services/remote/mock/mock_remote_exec_service.jsonc" - + " --scheduler schedulers/sync_scheduler.jsonc" - + f" --environment {env_conf_path}" - + " --globals globals/global_test_config.jsonc" - + " --globals globals/global_test_extra_config.jsonc" - " --test_global_value_2 from-args" - ) + env_conf_path = 'environments/mock/mock_env.jsonc' + cli_args = '--config-paths ' + ' '.join(config_paths) + \ + ' --service services/remote/mock/mock_auth_service.jsonc' + \ + ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ + ' --scheduler schedulers/sync_scheduler.jsonc' + \ + f' --environment {env_conf_path}' + \ + ' --globals globals/global_test_config.jsonc' + \ + ' --globals globals/global_test_extra_config.jsonc' \ + ' --test_global_value_2 from-args' launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -88,35 +85,30 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsLocalExec) assert isinstance(launcher.service, SupportsRemoteExec) # Check that the first --globals file is loaded and $var expansion is handled. - assert launcher.global_config["experiment_id"] == "MockExperiment" - assert launcher.global_config["testVmName"] == "MockExperiment-vm" + assert launcher.global_config['experiment_id'] == 'MockExperiment' + assert launcher.global_config['testVmName'] == 'MockExperiment-vm' # Check that secondary expansion also works. - assert launcher.global_config["testVnetName"] == "MockExperiment-vm-vnet" + assert launcher.global_config['testVnetName'] == 'MockExperiment-vm-vnet' # Check that the second --globals file is loaded. - assert launcher.global_config["test_global_value"] == "from-file" + assert launcher.global_config['test_global_value'] == 'from-file' # Check overriding values in a file from the command line. - assert launcher.global_config["test_global_value_2"] == "from-args" + assert launcher.global_config['test_global_value_2'] == 'from-args' # Check that we can expand a $var in a config file that references an environment variable. 
- assert path_join( - launcher.global_config["pathVarWithEnvVarRef"], abs_path=True - ) == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ + == path_join(os.getcwd(), "foo", abs_path=True) + assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' assert launcher.teardown # Check that the environment that got loaded looks to be of the right type. - env_config = launcher.config_loader.load_config( - env_conf_path, ConfigSchema.ENVIRONMENT - ) - assert check_class_name(launcher.environment, env_config["class"]) + env_config = launcher.config_loader.load_config(env_conf_path, ConfigSchema.ENVIRONMENT) + assert check_class_name(launcher.environment, env_config['class']) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, OneShotOptimizer) # Check that the optimizer got initialized with defaults. assert launcher.optimizer.tunable_params.is_defaults() - assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer + assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer # Check that we pick up the right scheduler config: assert isinstance(launcher.scheduler, SyncScheduler) - assert ( - launcher.scheduler._trial_config_repeat_count == 3 - ) # pylint: disable=protected-access + assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access assert launcher.scheduler._max_trials == -1 # pylint: disable=protected-access @@ -130,25 +122,23 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == "win32": + if sys.platform == 'win32': # Some env tweaks for platform compatibility. - environ["USER"] = environ["USERNAME"] - - config_file = "cli/test-cli-config.jsonc" - globals_file = "globals/global_test_config.jsonc" - cli_args = ( - " ".join([f"--config-path {config_path}" for config_path in config_paths]) - + f" --config {config_file}" - + " --service services/remote/mock/mock_auth_service.jsonc" - + " --service services/remote/mock/mock_remote_exec_service.jsonc" - + f" --globals {globals_file}" - + " --experiment_id MockeryExperiment" - + " --no-teardown" - + " --random-init" - + " --random-seed 1234" - + " --trial-config-repeat-count 5" - + " --max_trials 200" - ) + environ['USER'] = environ['USERNAME'] + + config_file = 'cli/test-cli-config.jsonc' + globals_file = 'globals/global_test_config.jsonc' + cli_args = ' '.join([f"--config-path {config_path}" for config_path in config_paths]) + \ + f' --config {config_file}' + \ + ' --service services/remote/mock/mock_auth_service.jsonc' + \ + ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ + f' --globals {globals_file}' + \ + ' --experiment_id MockeryExperiment' + \ + ' --no-teardown' + \ + ' --random-init' + \ + ' --random-seed 1234' + \ + ' --trial-config-repeat-count 5' + \ + ' --max_trials 200' launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -158,48 +148,35 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsRemoteExec) # Check that the --globals file is loaded and $var expansion is handled # using the value provided on the CLI. 
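
Editor's note on the quote changes in the launcher test hunks (for example 'experiment_id' versus "experiment_id"): black normalizes string literals to double quotes unless doing so would require adding backslash escapes. A minimal, self-contained illustration follows; it is not code from this repository.

import sys

is_win_a = sys.platform == "win32"  # the double-quoted form black emits
is_win_b = sys.platform == 'win32'  # the pre-black single-quoted form
assert is_win_a == is_win_b  # identical string values at runtime

# The one exception: black keeps single quotes when switching would add escapes.
message = 'say "hello"'  # left as-is, because "say \"hello\"" needs backslashes
print(message)
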
- assert launcher.global_config["experiment_id"] == "MockeryExperiment" - assert launcher.global_config["testVmName"] == "MockeryExperiment-vm" + assert launcher.global_config['experiment_id'] == 'MockeryExperiment' + assert launcher.global_config['testVmName'] == 'MockeryExperiment-vm' # Check that secondary expansion also works. - assert launcher.global_config["testVnetName"] == "MockeryExperiment-vm-vnet" + assert launcher.global_config['testVnetName'] == 'MockeryExperiment-vm-vnet' # Check that we can expand a $var in a config file that references an environment variable. - assert path_join( - launcher.global_config["pathVarWithEnvVarRef"], abs_path=True - ) == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ + == path_join(os.getcwd(), "foo", abs_path=True) + assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' assert not launcher.teardown config = launcher.config_loader.load_config(config_file, ConfigSchema.CLI) - assert launcher.config_loader.config_paths == [ - path_join(path, abs_path=True) for path in config_paths + config["config_path"] - ] + assert launcher.config_loader.config_paths == [path_join(path, abs_path=True) for path in config_paths + config['config_path']] # Check that the environment that got loaded looks to be of the right type. - env_config_file = config["environment"] - env_config = launcher.config_loader.load_config( - env_config_file, ConfigSchema.ENVIRONMENT - ) - assert check_class_name(launcher.environment, env_config["class"]) + env_config_file = config['environment'] + env_config = launcher.config_loader.load_config(env_config_file, ConfigSchema.ENVIRONMENT) + assert check_class_name(launcher.environment, env_config['class']) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, MlosCoreOptimizer) - opt_config_file = config["optimizer"] - opt_config = launcher.config_loader.load_config( - opt_config_file, ConfigSchema.OPTIMIZER - ) - globals_file_config = launcher.config_loader.load_config( - globals_file, ConfigSchema.GLOBALS - ) + opt_config_file = config['optimizer'] + opt_config = launcher.config_loader.load_config(opt_config_file, ConfigSchema.OPTIMIZER) + globals_file_config = launcher.config_loader.load_config(globals_file, ConfigSchema.GLOBALS) # The actual global_config gets overwritten as a part of processing, so to test # this we read the original value out of the source files. - orig_max_iters = globals_file_config.get( - "max_suggestions", opt_config.get("config", {}).get("max_suggestions", 100) - ) - assert ( - launcher.optimizer.max_iterations - == orig_max_iters - == launcher.global_config["max_suggestions"] - ) + orig_max_iters = globals_file_config.get('max_suggestions', opt_config.get('config', {}).get('max_suggestions', 100)) + assert launcher.optimizer.max_iterations \ + == orig_max_iters \ + == launcher.global_config['max_suggestions'] # Check that the optimizer got initialized with random values instead of the defaults. 
# Note: the environment doesn't get updated until suggest() is called to @@ -212,18 +189,16 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # Check that CLI parameter overrides JSON config: assert isinstance(launcher.scheduler, SyncScheduler) - assert ( - launcher.scheduler._trial_config_repeat_count == 5 - ) # pylint: disable=protected-access + assert launcher.scheduler._trial_config_repeat_count == 5 # pylint: disable=protected-access assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access # Check that the value from the file is overridden by the CLI arg. - assert config["random_seed"] == 42 + assert config['random_seed'] == 42 # TODO: This isn't actually respected yet because the `--random-init` only # applies to a temporary Optimizer used to populate the initial values via # random sampling. # assert launcher.optimizer.seed == 1234 -if __name__ == "__main__": +if __name__ == '__main__': pytest.main([__file__, "-n1"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index 508923f37d..591501d275 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -31,24 +31,16 @@ def local_exec_service() -> LocalExecService: """ Test fixture for LocalExecService. """ - return LocalExecService( - parent=ConfigPersistenceService( - { - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - } - ) - ) + return LocalExecService(parent=ConfigPersistenceService({ + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + })) -def _launch_main_app( - root_path: str, - local_exec_service: LocalExecService, - cli_config: str, - re_expected: List[str], -) -> None: +def _launch_main_app(root_path: str, local_exec_service: LocalExecService, + cli_config: str, re_expected: List[str]) -> None: """ Run mlos_bench command-line application with given config and check the results in the log. @@ -60,13 +52,10 @@ def _launch_main_app( # temp_dir = '/tmp' log_path = path_join(temp_dir, "mock-test.log") (return_code, _stdout, _stderr) = local_exec_service.local_exec( - [ - "./mlos_bench/mlos_bench/run.py" - + " --config_path ./mlos_bench/mlos_bench/tests/config/" - + f" {cli_config} --log_file '{log_path}'" - ], - cwd=root_path, - ) + ["./mlos_bench/mlos_bench/run.py" + + " --config_path ./mlos_bench/mlos_bench/tests/config/" + + f" {cli_config} --log_file '{log_path}'"], + cwd=root_path) assert return_code == 0 try: @@ -84,73 +73,65 @@ def _launch_main_app( _RE_DATE = r"\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}" -def test_launch_main_app_bench( - root_path: str, local_exec_service: LocalExecService -) -> None: +def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecService) -> None: """ Run mlos_bench command-line application with mock benchmark config and default tunable values and check the results in the log. """ _launch_main_app( - root_path, - local_exec_service, - " --config cli/mock-bench.jsonc" - + " --trial_config_repeat_count 5" - + " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, local_exec_service, + " --config cli/mock-bench.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. 
[ - f"^{_RE_DATE} run\\.py:\\d+ " - + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", - ], + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", + ] ) def test_launch_main_app_bench_values( - root_path: str, local_exec_service: LocalExecService -) -> None: + root_path: str, local_exec_service: LocalExecService) -> None: """ Run mlos_bench command-line application with mock benchmark config and user-specified tunable values and check the results in the log. """ _launch_main_app( - root_path, - local_exec_service, - " --config cli/mock-bench.jsonc" - + " --tunable_values tunable-values/tunable-values-example.jsonc" - + " --trial_config_repeat_count 5" - + " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, local_exec_service, + " --config cli/mock-bench.jsonc" + + " --tunable_values tunable-values/tunable-values-example.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. [ - f"^{_RE_DATE} run\\.py:\\d+ " - + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", - ], + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", + ] ) -def test_launch_main_app_opt( - root_path: str, local_exec_service: LocalExecService -) -> None: +def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecService) -> None: """ Run mlos_bench command-line application with mock optimization config and check the results in the log. """ _launch_main_app( - root_path, - local_exec_service, - "--config cli/mock-opt.jsonc" - + " --trial_config_repeat_count 3" - + " --max_suggestions 3" - + " --mock_env_seed 42", # Noisy Mock Environment. + root_path, local_exec_service, + "--config cli/mock-opt.jsonc" + + " --trial_config_repeat_count 3" + + " --max_suggestions 3" + + " --mock_env_seed 42", # Noisy Mock Environment. 
[ # Iteration 1: Expect first value to be the baseline - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " - + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", # Iteration 2: The result may not always be deterministic - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " - + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Iteration 3: non-deterministic (depends on the optimizer) - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " - + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Final result: baseline is the optimum for the mock environment - f"^{_RE_DATE} run\\.py:\\d+ " - + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", - ], + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", + ] ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index 6e22350d00..59a0fac13b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -23,29 +23,29 @@ def mock_configs() -> List[dict]: """ return [ { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 50000, - "kernel_sched_latency_ns": 1000000, + 'vmSize': 'Standard_B4ms', + 'idle': 'halt', + 'kernel_sched_migration_cost_ns': 50000, + 'kernel_sched_latency_ns': 1000000, }, { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000, - "kernel_sched_latency_ns": 2000000, + 'vmSize': 'Standard_B4ms', + 'idle': 'halt', + 'kernel_sched_migration_cost_ns': 40000, + 'kernel_sched_latency_ns': 2000000, }, { - "vmSize": "Standard_B4ms", - "idle": "mwait", - "kernel_sched_migration_cost_ns": -1, # Special value - "kernel_sched_latency_ns": 3000000, + 'vmSize': 'Standard_B4ms', + 'idle': 'mwait', + 'kernel_sched_migration_cost_ns': -1, # Special value + 'kernel_sched_latency_ns': 3000000, }, { - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 200000, - "kernel_sched_latency_ns": 4000000, - }, + 'vmSize': 'Standard_B2s', + 'idle': 'mwait', + 'kernel_sched_migration_cost_ns': 200000, + 'kernel_sched_latency_ns': 4000000, + } ] @@ -61,7 +61,7 @@ def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: "optimization_targets": {"score": "min"}, "max_suggestions": 5, "start_with_defaults": False, - "seed": SEED, + "seed": SEED }, ) @@ -77,7 +77,7 @@ def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: config={ "optimization_targets": {"score": "min"}, "max_suggestions": 5, - "seed": SEED, + "seed": SEED }, ) @@ -93,7 +93,7 @@ def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: config={ "optimization_targets": {"score": "max"}, "max_suggestions": 10, - "seed": SEED, + "seed": SEED }, ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index cceac9099b..9e9ce25d6f 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -20,7 +20,6 @@ # pylint: 
disable=redefined-outer-name - @pytest.fixture def grid_search_tunables_config() -> dict: """ @@ -52,27 +51,14 @@ def grid_search_tunables_config() -> dict: @pytest.fixture -def grid_search_tunables_grid( - grid_search_tunables: TunableGroups, -) -> List[Dict[str, TunableValue]]: +def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]: """ Test fixture for grid from tunable groups. Used to check that the grids are the same (ignoring order). """ - tunables_params_values = [ - tunable.values - for tunable, _group in grid_search_tunables - if tunable.values is not None - ] - tunable_names = tuple( - tunable.name - for tunable, _group in grid_search_tunables - if tunable.values is not None - ) - return list( - dict(zip(tunable_names, combo)) - for combo in itertools.product(*tunables_params_values) - ) + tunables_params_values = [tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None] + tunable_names = tuple(tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None) + return list(dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values)) @pytest.fixture @@ -84,10 +70,8 @@ def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups: @pytest.fixture -def grid_search_opt( - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]], -) -> GridSearchOptimizer: +def grid_search_opt(grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer: """ Test fixture for grid search optimizer. """ @@ -95,27 +79,20 @@ def grid_search_opt( # Test the convergence logic by controlling the number of iterations to be not a # multiple of the number of elements in the grid. max_iterations = len(grid_search_tunables_grid) * 2 - 3 - return GridSearchOptimizer( - tunables=grid_search_tunables, - config={ - "max_suggestions": max_iterations, - "optimization_targets": {"score": "max", "other_score": "min"}, - }, - ) + return GridSearchOptimizer(tunables=grid_search_tunables, config={ + "max_suggestions": max_iterations, + "optimization_targets": {"score": "max", "other_score": "min"}, + }) -def test_grid_search_grid( - grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]], -) -> None: +def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: """ Make sure that grid search optimizer initializes and works correctly. """ # Check the size. - expected_grid_size = math.prod( - tunable.cardinality for tunable, _group in grid_search_tunables - ) + expected_grid_size = math.prod(tunable.cardinality for tunable, _group in grid_search_tunables) assert expected_grid_size > len(grid_search_tunables) assert len(grid_search_tunables_grid) == expected_grid_size # Check for specific example configs inclusion. @@ -131,23 +108,15 @@ def test_grid_search_grid( # Note: ConfigSpace param name vs TunableGroup parameter name order is not # consistent, so we need to full dict comparison. 
assert len(grid_search_opt_pending_configs) == expected_grid_size - assert all( - config in grid_search_tunables_grid - for config in grid_search_opt_pending_configs - ) - assert all( - config in grid_search_opt_pending_configs - for config in grid_search_tunables_grid - ) + assert all(config in grid_search_tunables_grid for config in grid_search_opt_pending_configs) + assert all(config in grid_search_opt_pending_configs for config in grid_search_tunables_grid) # Order is less relevant to us, so we'll just check that the sets are the same. # assert grid_search_opt.pending_configs == grid_search_tunables_grid -def test_grid_search( - grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]], -) -> None: +def test_grid_search(grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: """ Make sure that grid search optimizer initializes and works correctly. """ @@ -173,14 +142,8 @@ def test_grid_search( grid_search_tunables_grid.remove(default_config) assert default_config not in grid_search_opt.pending_configs - assert all( - config in grid_search_tunables_grid - for config in grid_search_opt.pending_configs - ) - assert all( - config in list(grid_search_opt.pending_configs) - for config in grid_search_tunables_grid - ) + assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) + assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) # The next suggestion should be a different element in the grid search. suggestion = grid_search_opt.suggest() @@ -193,14 +156,8 @@ def test_grid_search( assert suggestion_dict not in grid_search_opt.suggested_configs grid_search_tunables_grid.remove(suggestion.get_param_values()) - assert all( - config in grid_search_tunables_grid - for config in grid_search_opt.pending_configs - ) - assert all( - config in list(grid_search_opt.pending_configs) - for config in grid_search_tunables_grid - ) + assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) + assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) # We consider not_converged as either having reached "max_suggestions" or an empty grid? 
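
Editor's note on the assert all(...) churn just above: black wraps any expression that exceeds the configured line length (88 columns by default; this project may configure a different limit, which is not asserted here) by breaking inside the outermost brackets rather than with backslash continuations. The names below are hypothetical, shortened stand-ins for the ones in the surrounding hunks.

grid = [{"idle": "halt"}, {"idle": "mwait"}]
pending = [{"idle": "mwait"}, {"idle": "halt"}]

# Short enough to stay on one line under an 88-column limit:
assert all(config in grid for config in pending)

# With longer names the same expression no longer fits, so black splits it
# inside the call's parentheses, one clause per line:
grid_search_tunables_grid_demo = grid
grid_search_opt_pending_configs_demo = pending
assert all(
    config in grid_search_tunables_grid_demo
    for config in grid_search_opt_pending_configs_demo
)
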
@@ -266,7 +223,7 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: assert best_suggestion_dict not in grid_search_opt.suggested_configs best_suggestion_score: Dict[str, TunableValue] = {} - for opt_target, opt_dir in grid_search_opt.targets.items(): + for (opt_target, opt_dir) in grid_search_opt.targets.items(): val = score[opt_target] assert isinstance(val, (int, float)) best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1 @@ -280,57 +237,36 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: # Check bulk register suggested = [grid_search_opt.suggest() for _ in range(suggest_count)] - assert all( - suggestion.get_param_values() not in grid_search_opt.pending_configs - for suggestion in suggested - ) - assert all( - suggestion.get_param_values() in grid_search_opt.suggested_configs - for suggestion in suggested - ) + assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) + assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested) # Those new suggestions also shouldn't be in the set of previously suggested configs. - assert all( - suggestion.get_param_values() not in suggested_shuffled - for suggestion in suggested - ) - - grid_search_opt.bulk_register( - [suggestion.get_param_values() for suggestion in suggested], - [score] * len(suggested), - [status] * len(suggested), - ) - - assert all( - suggestion.get_param_values() not in grid_search_opt.pending_configs - for suggestion in suggested - ) - assert all( - suggestion.get_param_values() not in grid_search_opt.suggested_configs - for suggestion in suggested - ) + assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested) + + grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested], + [score] * len(suggested), + [status] * len(suggested)) + + assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) + assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested) best_score, best_config = grid_search_opt.get_best_observation() assert best_score == best_suggestion_score assert best_config == best_suggestion -def test_grid_search_register( - grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups -) -> None: +def test_grid_search_register(grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups) -> None: """ Make sure that the `.register()` method adjusts the score signs correctly. 
""" assert grid_search_opt.register( - grid_search_tunables, - Status.SUCCEEDED, - { + grid_search_tunables, Status.SUCCEEDED, { "score": 1.0, "other_score": 2.0, - }, - ) == { - "score": -1.0, # max - "other_score": 2.0, # min + }) == { + "score": -1.0, # max + "other_score": 2.0, # min } assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == { diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py index 07eec4655f..6549a8795c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py @@ -34,8 +34,7 @@ def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: "optimizer_type": "SMAC", "seed": SEED, # "start_with_defaults": False, - }, - ) + }) @pytest.fixture @@ -46,9 +45,7 @@ def mock_scores() -> list: return [88.88, 66.66, 99.99] -def test_llamatune_optimizer( - llamatune_opt: MlosCoreOptimizer, mock_scores: list -) -> None: +def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None: """ Make sure that llamatune+smac optimizer initializes and works correctly. """ @@ -64,6 +61,6 @@ def test_llamatune_optimizer( assert best_score["score"] == pytest.approx(66.66, 0.01) -if __name__ == "__main__": +if __name__ == '__main__': # For attaching debugger debugging: pytest.main(["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__]) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py index c824d9774f..7ebba0e664 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py @@ -24,9 +24,9 @@ def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer: An instance of a mlos_core optimizer (FLAML-based). 
""" test_opt_config = { - "optimizer_type": "FLAML", - "max_suggestions": 10, - "seed": SEED, + 'optimizer_type': 'FLAML', + 'max_suggestions': 10, + 'seed': SEED, } return MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -39,44 +39,44 @@ def test_df(mlos_core_optimizer: MlosCoreOptimizer, mock_configs: List[dict]) -> assert isinstance(df_config, pandas.DataFrame) assert df_config.shape == (4, 6) assert set(df_config.columns) == { - "kernel_sched_latency_ns", - "kernel_sched_migration_cost_ns", - "kernel_sched_migration_cost_ns!type", - "kernel_sched_migration_cost_ns!special", - "idle", - "vmSize", + 'kernel_sched_latency_ns', + 'kernel_sched_migration_cost_ns', + 'kernel_sched_migration_cost_ns!type', + 'kernel_sched_migration_cost_ns!special', + 'idle', + 'vmSize', } - assert df_config.to_dict(orient="records") == [ + assert df_config.to_dict(orient='records') == [ { - "idle": "halt", - "kernel_sched_latency_ns": 1000000, - "kernel_sched_migration_cost_ns": 50000, - "kernel_sched_migration_cost_ns!special": None, - "kernel_sched_migration_cost_ns!type": "range", - "vmSize": "Standard_B4ms", + 'idle': 'halt', + 'kernel_sched_latency_ns': 1000000, + 'kernel_sched_migration_cost_ns': 50000, + 'kernel_sched_migration_cost_ns!special': None, + 'kernel_sched_migration_cost_ns!type': 'range', + 'vmSize': 'Standard_B4ms', }, { - "idle": "halt", - "kernel_sched_latency_ns": 2000000, - "kernel_sched_migration_cost_ns": 40000, - "kernel_sched_migration_cost_ns!special": None, - "kernel_sched_migration_cost_ns!type": "range", - "vmSize": "Standard_B4ms", + 'idle': 'halt', + 'kernel_sched_latency_ns': 2000000, + 'kernel_sched_migration_cost_ns': 40000, + 'kernel_sched_migration_cost_ns!special': None, + 'kernel_sched_migration_cost_ns!type': 'range', + 'vmSize': 'Standard_B4ms', }, { - "idle": "mwait", - "kernel_sched_latency_ns": 3000000, - "kernel_sched_migration_cost_ns": None, # The value is special! - "kernel_sched_migration_cost_ns!special": -1, - "kernel_sched_migration_cost_ns!type": "special", - "vmSize": "Standard_B4ms", + 'idle': 'mwait', + 'kernel_sched_latency_ns': 3000000, + 'kernel_sched_migration_cost_ns': None, # The value is special! + 'kernel_sched_migration_cost_ns!special': -1, + 'kernel_sched_migration_cost_ns!type': 'special', + 'vmSize': 'Standard_B4ms', }, { - "idle": "mwait", - "kernel_sched_latency_ns": 4000000, - "kernel_sched_migration_cost_ns": 200000, - "kernel_sched_migration_cost_ns!special": None, - "kernel_sched_migration_cost_ns!type": "range", - "vmSize": "Standard_B2s", + 'idle': 'mwait', + 'kernel_sched_latency_ns': 4000000, + 'kernel_sched_migration_cost_ns': 200000, + 'kernel_sched_migration_cost_ns!special': None, + 'kernel_sched_migration_cost_ns!type': 'range', + 'vmSize': 'Standard_B2s', }, ] diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index 95d51cbe22..fc62b4ff1b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -17,8 +17,8 @@ from mlos_bench.util import path_join from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer -_OUTPUT_DIR_PATH_BASE = r"c:/temp" if sys.platform == "win32" else "/tmp/" -_OUTPUT_DIR = "_test_output_dir" # Will be deleted after the test. +_OUTPUT_DIR_PATH_BASE = r'c:/temp' if sys.platform == 'win32' else '/tmp/' +_OUTPUT_DIR = '_test_output_dir' # Will be deleted after the test. 
def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) -> None: @@ -26,10 +26,10 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) Test invalid max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "max_trials": 10, - "max_suggestions": 11, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'max_trials': 10, + 'max_suggestions': 11, + 'seed': SEED, } with pytest.raises(AssertionError): opt = MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -41,29 +41,25 @@ def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> No Test max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "max_suggestions": 123, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'max_suggestions': 123, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) - assert ( - opt._opt.base_optimizer.scenario.n_trials == test_opt_config["max_suggestions"] - ) + assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config['max_suggestions'] -def test_init_mlos_core_smac_absolute_output_directory( - tunable_groups: TunableGroups, -) -> None: +def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGroups) -> None: """ Test absolute path output directory initialization of mlos_core SMAC optimizer. """ output_dir = path_join(_OUTPUT_DIR_PATH_BASE, _OUTPUT_DIR) test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": output_dir, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': output_dir, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) @@ -71,96 +67,76 @@ def test_init_mlos_core_smac_absolute_output_directory( assert isinstance(opt._opt, SmacOptimizer) # Final portions of the path are generated by SMAC when run_name is not specified. assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - str(test_opt_config["output_directory"]) - ) + str(test_opt_config['output_directory'])) shutil.rmtree(output_dir) -def test_init_mlos_core_smac_relative_output_directory( - tunable_groups: TunableGroups, -) -> None: +def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGroups) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": _OUTPUT_DIR, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': _OUTPUT_DIR, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config["output_directory"])) - ) + path_join(os.getcwd(), str(test_opt_config['output_directory']))) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_run_name( - tunable_groups: TunableGroups, -) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_groups: TunableGroups) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. 
""" test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": _OUTPUT_DIR, - "run_name": "test_run", - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': _OUTPUT_DIR, + 'run_name': 'test_run', + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join( - os.getcwd(), - str(test_opt_config["output_directory"]), - str(test_opt_config["run_name"]), - ) - ) + path_join(os.getcwd(), str(test_opt_config['output_directory']), str(test_opt_config['run_name']))) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_experiment_id( - tunable_groups: TunableGroups, -) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunable_groups: TunableGroups) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": _OUTPUT_DIR, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': _OUTPUT_DIR, + 'seed': SEED, } global_config = { - "experiment_id": "experiment_id", + 'experiment_id': 'experiment_id', } opt = MlosCoreOptimizer(tunable_groups, test_opt_config, global_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join( - os.getcwd(), - str(test_opt_config["output_directory"]), - global_config["experiment_id"], - ) - ) + path_join(os.getcwd(), str(test_opt_config['output_directory']), global_config['experiment_id'])) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_temp_output_directory( - tunable_groups: TunableGroups, -) -> None: +def test_init_mlos_core_smac_temp_output_directory(tunable_groups: TunableGroups) -> None: """ Test random output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": None, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': None, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py index 739e27114b..a94a315939 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py @@ -20,33 +20,24 @@ def mock_configurations_no_defaults() -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. 
""" return [ - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 13112, - "kernel_sched_latency_ns": 796233790, - }, - 88.88, - ), - ( - { - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 117026, - "kernel_sched_latency_ns": 149827706, - }, - 66.66, - ), - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 354785, - "kernel_sched_latency_ns": 795285932, - }, - 99.99, - ), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 13112, + "kernel_sched_latency_ns": 796233790, + }, 88.88), + ({ + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 117026, + "kernel_sched_latency_ns": 149827706, + }, 66.66), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 354785, + "kernel_sched_latency_ns": 795285932, + }, 99.99), ] @@ -56,15 +47,12 @@ def mock_configurations(mock_configurations_no_defaults: list) -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. """ return [ - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": -1, - "kernel_sched_latency_ns": 2000000, - }, - 88.88, - ), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": -1, + "kernel_sched_latency_ns": 2000000, + }, 88.88), ] + mock_configurations_no_defaults @@ -72,7 +60,7 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: """ Run several iterations of the optimizer and return the best score. """ - for tunable_values, score in mock_configurations: + for (tunable_values, score) in mock_configurations: assert mock_opt.not_converged() tunables = mock_opt.suggest() assert tunables.get_param_values() == tunable_values @@ -92,9 +80,8 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> N assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_no_defaults( - mock_opt_no_defaults: MockOptimizer, mock_configurations_no_defaults: list -) -> None: +def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer, + mock_configurations_no_defaults: list) -> None: """ Make sure that mock optimizer produces consistent suggestions. """ @@ -102,9 +89,7 @@ def test_mock_optimizer_no_defaults( assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_max( - mock_opt_max: MockOptimizer, mock_configurations: list -) -> None: +def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list) -> None: """ Check the maximization mode of the mock optimizer. """ diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py index ccc0ba8137..bf37040f13 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py @@ -25,7 +25,10 @@ def mock_configs_str(mock_configs: List[dict]) -> List[dict]: Same as `mock_config` above, but with all values converted to strings. (This can happen when we retrieve the data from storage). 
""" - return [{key: str(val) for (key, val) in config.items()} for config in mock_configs] + return [ + {key: str(val) for (key, val) in config.items()} + for config in mock_configs + ] @pytest.fixture @@ -49,12 +52,10 @@ def mock_status() -> List[Status]: return [Status.FAILED, Status.SUCCEEDED, Status.SUCCEEDED, Status.SUCCEEDED] -def _test_opt_update_min( - opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None, -) -> None: +def _test_opt_update_min(opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None) -> None: """ Test the bulk update of the optimizer on the minimization problem. """ @@ -67,16 +68,14 @@ def _test_opt_update_min( "vmSize": "Standard_B4ms", "idle": "mwait", "kernel_sched_migration_cost_ns": -1, - "kernel_sched_latency_ns": 3000000, + 'kernel_sched_latency_ns': 3000000, } -def _test_opt_update_max( - opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None, -) -> None: +def _test_opt_update_max(opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None) -> None: """ Test the bulk update of the optimizer on the maximization problem. """ @@ -89,16 +88,14 @@ def _test_opt_update_max( "vmSize": "Standard_B2s", "idle": "mwait", "kernel_sched_migration_cost_ns": 200000, - "kernel_sched_latency_ns": 4000000, + 'kernel_sched_latency_ns': 4000000, } -def test_update_mock_min( - mock_opt: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_mock_min(mock_opt: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer on the minimization problem. """ @@ -108,76 +105,64 @@ def test_update_mock_min( "vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 13112, - "kernel_sched_latency_ns": 796233790, + 'kernel_sched_latency_ns': 796233790, } -def test_update_mock_min_str( - mock_opt: MockOptimizer, - mock_configs_str: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_mock_min_str(mock_opt: MockOptimizer, + mock_configs_str: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer with all-strings data. """ _test_opt_update_min(mock_opt, mock_configs_str, mock_scores, mock_status) -def test_update_mock_max( - mock_opt_max: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_mock_max(mock_opt_max: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer on the maximization problem. 
""" _test_opt_update_max(mock_opt_max, mock_configs, mock_scores, mock_status) -def test_update_flaml( - flaml_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_flaml(flaml_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_min(flaml_opt, mock_configs, mock_scores, mock_status) -def test_update_flaml_max( - flaml_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_flaml_max(flaml_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_max(flaml_opt_max, mock_configs, mock_scores, mock_status) -def test_update_smac( - smac_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_smac(smac_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the SMAC optimizer. """ _test_opt_update_min(smac_opt, mock_configs, mock_scores, mock_status) -def test_update_smac_max( - smac_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_smac_max(smac_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the SMAC optimizer. 
""" diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index d5068e0656..2a50f95e8c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -42,16 +42,12 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: logger("tunables: %s", str(tunables)) # pylint: disable=protected-access - if isinstance(opt, MlosCoreOptimizer) and isinstance( - opt._opt, SmacOptimizer - ): + if isinstance(opt, MlosCoreOptimizer) and isinstance(opt._opt, SmacOptimizer): config = tunable_values_to_configuration(tunables) config_df = config_to_dataframe(config) logger("config: %s", str(config)) try: - logger( - "prediction: %s", opt._opt.surrogate_predict(configs=config_df) - ) + logger("prediction: %s", opt._opt.surrogate_predict(configs=config_df)) except RuntimeError: pass @@ -60,7 +56,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: (status, _ts, output) = env_context.run() assert status.is_succeeded() assert output is not None - score = output["score"] + score = output['score'] assert isinstance(score, float) assert 60 <= score <= 120 logger("score: %s", str(score)) @@ -73,9 +69,8 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: return (best_score["score"], best_tunables) -def test_mock_optimization_loop( - mock_env_no_noise: MockEnv, mock_opt: MockOptimizer -) -> None: +def test_mock_optimization_loop(mock_env_no_noise: MockEnv, + mock_opt: MockOptimizer) -> None: """ Toy optimization loop with mock environment and optimizer. """ @@ -89,9 +84,8 @@ def test_mock_optimization_loop( } -def test_mock_optimization_loop_no_defaults( - mock_env_no_noise: MockEnv, mock_opt_no_defaults: MockOptimizer -) -> None: +def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, + mock_opt_no_defaults: MockOptimizer) -> None: """ Toy optimization loop with mock environment and optimizer. """ @@ -105,9 +99,8 @@ def test_mock_optimization_loop_no_defaults( } -def test_flaml_optimization_loop( - mock_env_no_noise: MockEnv, flaml_opt: MlosCoreOptimizer -) -> None: +def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, + flaml_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and FLAML optimizer. """ @@ -122,9 +115,8 @@ def test_flaml_optimization_loop( # @pytest.mark.skip(reason="SMAC is not deterministic") -def test_smac_optimization_loop( - mock_env_no_noise: MockEnv, smac_opt: MlosCoreOptimizer -) -> None: +def test_smac_optimization_loop(mock_env_no_noise: MockEnv, + smac_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and SMAC optimizer. 
""" diff --git a/mlos_bench/mlos_bench/tests/services/__init__.py b/mlos_bench/mlos_bench/tests/services/__init__.py index bf4df0e6c2..1971c01799 100644 --- a/mlos_bench/mlos_bench/tests/services/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/__init__.py @@ -11,8 +11,8 @@ from .remote import MockFileShareService, MockRemoteExecService, MockVMService __all__ = [ - "MockLocalExecService", - "MockFileShareService", - "MockRemoteExecService", - "MockVMService", + 'MockLocalExecService', + 'MockFileShareService', + 'MockRemoteExecService', + 'MockVMService', ] diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index 8f51dd9f85..d6cb869f09 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -29,24 +29,18 @@ def config_persistence_service() -> ConfigPersistenceService: """ Test fixture for ConfigPersistenceService. """ - return ConfigPersistenceService( - { - "config_path": [ - "./non-existent-dir/test/foo/bar", # Non-existent config path - ".", # cwd - str( - files("mlos_bench.tests.config").joinpath("") - ), # Test configs (relative to mlos_bench/tests) - # Shouldn't be necessary since we automatically add this. - # str(files("mlos_bench.config").joinpath("")), # Stock configs - ] - } - ) - - -def test_cwd_in_explicit_search_path( - config_persistence_service: ConfigPersistenceService, -) -> None: + return ConfigPersistenceService({ + "config_path": [ + "./non-existent-dir/test/foo/bar", # Non-existent config path + ".", # cwd + str(files("mlos_bench.tests.config").joinpath("")), # Test configs (relative to mlos_bench/tests) + # Shouldn't be necessary since we automatically add this. + # str(files("mlos_bench.config").joinpath("")), # Stock configs + ] + }) + + +def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None: """ Check that CWD is in the search path in the correct place. """ @@ -71,25 +65,20 @@ def test_cwd_in_default_search_path() -> None: config_persistence_service._config_path.index(cwd, 1) -def test_resolve_stock_path( - config_persistence_service: ConfigPersistenceService, -) -> None: +def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService) -> None: """ Check if we can actually find a file somewhere in `config_path`. """ # pylint: disable=protected-access assert config_persistence_service._config_path is not None - assert ( - ConfigPersistenceService.BUILTIN_CONFIG_PATH - in config_persistence_service._config_path - ) + assert ConfigPersistenceService.BUILTIN_CONFIG_PATH in config_persistence_service._config_path file_path = "storage/in-memory.jsonc" path = config_persistence_service.resolve_path(file_path) assert path.endswith(file_path) assert os.path.exists(path) assert os.path.samefile( ConfigPersistenceService.BUILTIN_CONFIG_PATH, - os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]), + os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]) ) @@ -103,9 +92,7 @@ def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> N assert os.path.exists(path) -def test_resolve_path_fail( - config_persistence_service: ConfigPersistenceService, -) -> None: +def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None: """ Check if non-existent file resolves without using `config_path`. 
""" @@ -119,9 +106,8 @@ def test_load_config(config_persistence_service: ConfigPersistenceService) -> No """ Check if we can successfully load a config file located relative to `config_path`. """ - tunables_data = config_persistence_service.load_config( - "tunable-values/tunable-values-example.jsonc", ConfigSchema.TUNABLE_VALUES - ) + tunables_data = config_persistence_service.load_config("tunable-values/tunable-values-example.jsonc", + ConfigSchema.TUNABLE_VALUES) assert tunables_data is not None assert isinstance(tunables_data, dict) assert len(tunables_data) >= 1 diff --git a/mlos_bench/mlos_bench/tests/services/local/__init__.py b/mlos_bench/mlos_bench/tests/services/local/__init__.py index a09fd442fb..c6dbf7c021 100644 --- a/mlos_bench/mlos_bench/tests/services/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/__init__.py @@ -10,5 +10,5 @@ from .mock import MockLocalExecService __all__ = [ - "MockLocalExecService", + 'MockLocalExecService', ] diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index dafd8ed2fe..572195dcc5 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -56,22 +56,17 @@ def test_run_python_script(local_exec_service: LocalExecService) -> None: json.dump(params_meta, fh_meta) script_path = local_exec_service.config_loader_service.resolve_path( - "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py" - ) + "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py") - (return_code, _stdout, stderr) = local_exec_service.local_exec( - [f"{script_path} {input_file} {meta_file} {output_file}"], - cwd=temp_dir, - env=params, - ) + (return_code, _stdout, stderr) = local_exec_service.local_exec([ + f"{script_path} {input_file} {meta_file} {output_file}" + ], cwd=temp_dir, env=params) assert stderr.strip() == "" assert return_code == 0 # assert stdout.strip() == "" - with open( - path_join(temp_dir, output_file), "rt", encoding="utf-8" - ) as fh_output: + with open(path_join(temp_dir, output_file), "rt", encoding="utf-8") as fh_output: assert [ln.strip() for ln in fh_output.readlines()] == [ 'echo "40000" > /proc/sys/kernel/sched_migration_cost_ns', 'echo "800000" > /proc/sys/kernel/sched_granularity_ns', diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index 04f1f600f3..bd5b3b7d7f 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -26,23 +26,23 @@ def test_split_cmdline() -> None: """ cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" assert list(split_cmdline(cmdline)) == [ - [".", "env.sh"], - ["&&"], - ["("], - ["echo", "hello"], - ["&&"], - ["echo", "world"], - ["|"], - ["tee"], - [">"], - ["/tmp/test"], - ["||"], - ["echo", "foo"], - ["&&"], - ["echo", "$var"], - [";"], - ["true"], - [")"], + ['.', 'env.sh'], + ['&&'], + ['('], + ['echo', 'hello'], + ['&&'], + ['echo', 'world'], + ['|'], + ['tee'], + ['>'], + ['/tmp/test'], + ['||'], + ['echo', 'foo'], + ['&&'], + ['echo', '$var'], + [';'], + ['true'], + [')'], ] @@ -67,13 +67,8 @@ def test_resolve_script(local_exec_service: LocalExecService) -> None: expected_cmdline = f". 
env.sh && {script_abspath} --input foo" subcmds_tokens = split_cmdline(orig_cmdline) # pylint: disable=protected-access - subcmds_tokens = [ - local_exec_service._resolve_cmdline_script_path(subcmd_tokens) - for subcmd_tokens in subcmds_tokens - ] - cmdline_tokens = [ - token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens - ] + subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens] + cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens] expanded_cmdline = " ".join(cmdline_tokens) assert expanded_cmdline == expected_cmdline @@ -94,9 +89,10 @@ def test_run_script_multiline(local_exec_service: LocalExecService) -> None: Run a multiline script locally and check the results. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec( - ["echo hello", "echo world"] - ) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + "echo hello", + "echo world" + ]) assert return_code == 0 assert stdout.strip().split() == ["hello", "world"] assert stderr.strip() == "" @@ -107,12 +103,12 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: Run a multiline script locally and pass the environment variables to it. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec( - [r"echo $var", r"echo %var%"], # Unix shell # Windows cmd - env={"var": "VALUE", "int_var": 10}, - ) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + r"echo $var", # Unix shell + r"echo %var%" # Windows cmd + ], env={"var": "VALUE", "int_var": 10}) assert return_code == 0 - if sys.platform == "win32": + if sys.platform == 'win32': assert stdout.strip().split() == ["$var", "VALUE"] else: assert stdout.strip().split() == ["VALUE", "%var%"] @@ -125,26 +121,23 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: """ with local_exec_service.temp_dir_context() as temp_dir: - (return_code, stdout, stderr) = local_exec_service.local_exec( - [ - "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows - "echo '111,222' >> output.csv", - "echo '333,444' >> output.csv", - ], - cwd=temp_dir, - ) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows + "echo '111,222' >> output.csv", + "echo '333,444' >> output.csv", + ], cwd=temp_dir) assert return_code == 0 assert stdout.strip() == "" assert stderr.strip() == "" data = pandas.read_csv(path_join(temp_dir, "output.csv")) - if sys.platform == "win32": + if sys.platform == 'win32': # Workaround for Python's subprocess module on Windows adding a # space inbetween the col1,col2 arg and the redirect symbol which # cmd poorly interprets as being part of the original string arg. # Without this, we get "col2 " as the second column name. 
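The column-name workaround described in the comment above can be reproduced with plain pandas, independent of the Windows/cmd redirect behavior. A minimal standalone sketch follows (the "col2 " header with a trailing space is fabricated here to mirror the situation the comment describes), using the same rename/rstrip call applied in the hunk below:

import pandas

# A frame whose second header picked up a stray trailing space, as described above.
data = pandas.DataFrame({"col1": [111, 333], "col2 ": [222, 444]})

# Strip trailing whitespace from every column label in place.
data.rename(str.rstrip, axis="columns", inplace=True)

assert list(data.columns) == ["col1", "col2"]
assert all(data.col2 == [222, 444])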
- data.rename(str.rstrip, axis="columns", inplace=True) + data.rename(str.rstrip, axis='columns', inplace=True) assert all(data.col1 == [111, 333]) assert all(data.col2 == [222, 444]) @@ -159,13 +152,10 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None with open(path_join(temp_dir, input_file), "wt", encoding="utf-8") as fh_input: fh_input.write("hello\n") - (return_code, stdout, stderr) = local_exec_service.local_exec( - [ - f"echo 'world' >> {input_file}", - f"echo 'test' >> {input_file}", - ], - cwd=temp_dir, - ) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + f"echo 'world' >> {input_file}", + f"echo 'test' >> {input_file}", + ], cwd=temp_dir) assert return_code == 0 assert stdout.strip() == "" @@ -179,9 +169,7 @@ def test_run_script_fail(local_exec_service: LocalExecService) -> None: """ Try to run a non-existent command. """ - (return_code, stdout, _stderr) = local_exec_service.local_exec( - ["foo_bar_baz hello"] - ) + (return_code, stdout, _stderr) = local_exec_service.local_exec(["foo_bar_baz hello"]) assert return_code != 0 assert stdout.strip() == "" @@ -190,13 +178,11 @@ def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> N """ Try to run a series of commands, one of which fails, and abort early. """ - (return_code, stdout, _stderr) = local_exec_service.local_exec( - [ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == "win32" else "false", - "echo world", - ] - ) + (return_code, stdout, _stderr) = local_exec_service.local_exec([ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", + "echo world", + ]) assert return_code != 0 assert stdout.strip() == "hello" @@ -206,13 +192,11 @@ def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> No Try to run a series of commands, one of which fails, but let it pass. """ local_exec_service.abort_on_error = False - (return_code, stdout, _stderr) = local_exec_service.local_exec( - [ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == "win32" else "false", - "echo world", - ] - ) + (return_code, stdout, _stderr) = local_exec_service.local_exec([ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", + "echo world", + ]) assert return_code == 0 assert stdout.splitlines() == [ "hello", @@ -230,17 +214,13 @@ def test_temp_dir_path_expansion() -> None: # the fact. with tempfile.TemporaryDirectory() as temp_dir: global_config = { - "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" + "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" } config = { # The temp_dir for the LocalExecService should get expanded via workdir global config. 
"temp_dir": "$workdir/temp", } - local_exec_service = LocalExecService( - config, global_config, parent=ConfigPersistenceService() - ) + local_exec_service = LocalExecService(config, global_config, parent=ConfigPersistenceService()) # pylint: disable=protected-access assert isinstance(local_exec_service._temp_dir, str) - assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join( - temp_dir, "temp", abs_path=True - ) + assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join(temp_dir, "temp", abs_path=True) diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py index 9164da60df..eede9383bc 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py @@ -9,5 +9,5 @@ from .mock_local_exec_service import MockLocalExecService __all__ = [ - "MockLocalExecService", + 'MockLocalExecService', ] diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py index ad47160753..db8f0134c4 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py @@ -35,24 +35,16 @@ class MockLocalExecService(TempDirContextService, SupportsLocalExec): Mock methods for LocalExecService testing. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.local_exec]), + config, global_config, parent, + self.merge_methods(methods, [self.local_exec]) ) - def local_exec( - self, - script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None, - ) -> Tuple[int, str, str]: + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None) -> Tuple[int, str, str]: return (0, "", "") diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py index 4ef38ab440..835738015b 100644 --- a/mlos_bench/mlos_bench/tests/services/mock_service.py +++ b/mlos_bench/mlos_bench/tests/services/mock_service.py @@ -28,24 +28,19 @@ class MockServiceBase(Service, SupportsSomeMethod): """A base service class for testing.""" def __init__( - self, - config: Optional[dict] = None, - global_config: Optional[dict] = None, - parent: Optional[Service] = None, - methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None, - ) -> None: + self, + config: Optional[dict] = None, + global_config: Optional[dict] = None, + parent: Optional[Service] = None, + methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None) -> None: super().__init__( config, global_config, parent, - self.merge_methods( - methods, - [ - self.some_method, - self.some_other_method, - ], - ), - ) + self.merge_methods(methods, [ + self.some_method, + self.some_other_method, + ])) def some_method(self) -> str: """some_method""" diff 
--git a/mlos_bench/mlos_bench/tests/services/remote/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/__init__.py index df3fb69c53..e8a87ab684 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/__init__.py @@ -12,7 +12,7 @@ from .mock.mock_vm_service import MockVMService __all__ = [ - "MockFileShareService", - "MockRemoteExecService", - "MockVMService", + 'MockFileShareService', + 'MockRemoteExecService', + 'MockVMService', ] diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index 64633a534b..c6475e6936 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -18,25 +18,16 @@ @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_file( - mock_makedirs: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService, -) -> None: +def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = ( - azure_fileshare._share_client - ) # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object( - mock_share_client, "get_file_client" - ) as mock_get_file_client, patch.object( - mock_share_client, "get_directory_client" - ) as mock_get_directory_client: + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \ + patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client: mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False)) azure_fileshare.download(config, remote_path, local_path) @@ -56,45 +47,38 @@ def make_dir_client_returns(remote_folder: str) -> dict: return { remote_folder: Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock( - return_value=[ - {"name": "a_folder", "is_directory": True}, - {"name": "a_file_1.csv", "is_directory": False}, - ] - ), + list_directories_and_files=Mock(return_value=[ + {"name": "a_folder", "is_directory": True}, + {"name": "a_file_1.csv", "is_directory": False}, + ]) ), f"{remote_folder}/a_folder": Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock( - return_value=[ - {"name": "a_file_2.csv", "is_directory": False}, - ] - ), + list_directories_and_files=Mock(return_value=[ + {"name": "a_file_2.csv", "is_directory": False}, + ]) + ), + f"{remote_folder}/a_file_1.csv": Mock( + exists=Mock(return_value=False) + ), + f"{remote_folder}/a_folder/a_file_2.csv": Mock( + exists=Mock(return_value=False) ), - f"{remote_folder}/a_file_1.csv": Mock(exists=Mock(return_value=False)), - f"{remote_folder}/a_folder/a_file_2.csv": Mock(exists=Mock(return_value=False)), } @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_non_recursive( - mock_makedirs: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService, -) -> None: +def test_download_folder_non_recursive(mock_makedirs: MagicMock, 
+ mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = ( - azure_fileshare._share_client - ) # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object( - mock_share_client, "get_directory_client" - ) as mock_get_directory_client, patch.object( - mock_share_client, "get_file_client" - ) as mock_get_file_client: + with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ + patch.object(mock_share_client, "get_file_client") as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] @@ -103,69 +87,47 @@ def test_download_folder_non_recursive( mock_get_file_client.assert_called_with( f"{remote_folder}/a_file_1.csv", ) - mock_get_directory_client.assert_has_calls( - [ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - ], - any_order=True, - ) + mock_get_directory_client.assert_has_calls([ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + ], any_order=True) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_recursive( - mock_makedirs: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService, -) -> None: +def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = ( - azure_fileshare._share_client - ) # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object( - mock_share_client, "get_directory_client" - ) as mock_get_directory_client, patch.object( - mock_share_client, "get_file_client" - ) as mock_get_file_client: + with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ + patch.object(mock_share_client, "get_file_client") as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] azure_fileshare.download(config, remote_folder, local_folder, recursive=True) - mock_get_file_client.assert_has_calls( - [ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], - any_order=True, - ) - mock_get_directory_client.assert_has_calls( - [ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], - any_order=True, - ) + mock_get_file_client.assert_has_calls([ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], any_order=True) + mock_get_directory_client.assert_has_calls([ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], any_order=True) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") -def test_upload_file( - mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService -) -> None: +def test_upload_file(mock_isdir: MagicMock, 
mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = ( - azure_fileshare._share_client - ) # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access mock_isdir.return_value = False config: dict = {} @@ -181,7 +143,6 @@ def test_upload_file( class MyDirEntry: # pylint: disable=too-few-public-methods """Dummy class for os.DirEntry""" - def __init__(self, name: str, is_a_dir: bool): self.name = name self.is_a_dir = is_a_dir @@ -225,21 +186,17 @@ def process_paths(input_path: str) -> str: @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_non_recursive( - mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService, -) -> None: +def test_upload_directory_non_recursive(mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = ( - azure_fileshare._share_client - ) # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: @@ -251,30 +208,23 @@ def test_upload_directory_non_recursive( @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_recursive( - mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService, -) -> None: +def test_upload_directory_recursive(mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = ( - azure_fileshare._share_client - ) # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: azure_fileshare.upload(config, local_folder, remote_folder, recursive=True) - mock_get_file_client.assert_has_calls( - [ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], - any_order=True, - ) + mock_get_file_client.assert_has_calls([ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], any_order=True) diff --git 
a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py index 7a7a87359a..d6d55d3975 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py @@ -18,41 +18,27 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), - [ + ("total_retries", "operation_status"), [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ], -) + ]) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_network_deployment_retry( - mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_network_service: AzureNetworkService, -) -> None: +def test_wait_network_deployment_retry(mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_network_service: AzureNetworkService) -> None: """ Test retries of the network deployment operation. """ # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ - make_httplib_json_response( - 200, {"properties": {"provisioningState": "Running"}} - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - make_httplib_json_response( - 200, {"properties": {"provisioningState": "Running"}} - ), - make_httplib_json_response( - 200, {"properties": {"provisioningState": "Succeeded"}} - ), + make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), + make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] (status, _) = azure_network_service.wait_network_deployment( @@ -63,37 +49,30 @@ def test_wait_network_deployment_retry( "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True, - ) + is_setup=True) assert status == operation_status @pytest.mark.parametrize( - ("operation_name", "accepts_params"), - [ + ("operation_name", "accepts_params"), [ ("deprovision_network", True), - ], -) + ]) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), - [ + ("http_status_code", "operation_status"), [ (200, Status.SUCCEEDED), (202, Status.PENDING), # These should succeed since we set ignore_errors=True by default (401, Status.SUCCEEDED), (404, Status.SUCCEEDED), - ], -) + ]) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_network_operation_status( - mock_requests: MagicMock, - azure_network_service: AzureNetworkService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status, -) -> None: +def test_network_operation_status(mock_requests: MagicMock, + azure_network_service: AzureNetworkService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status) -> None: """ Test network operation status. 
""" @@ -105,37 +84,27 @@ def test_network_operation_status( with pytest.raises(ValueError): # Missing vnetName should raise ValueError (status, _) = operation({}) if accepts_params else operation() - (status, _) = ( - operation({"vnetName": "test-vnet"}) if accepts_params else operation() - ) + (status, _) = operation({"vnetName": "test-vnet"}) if accepts_params else operation() assert status == operation_status @pytest.fixture -def test_azure_network_service_no_deployment_template( - azure_auth_service: AzureAuthService, -) -> None: +def test_azure_network_service_no_deployment_template(azure_auth_service: AzureAuthService) -> None: """ Tests creating a network services without a deployment template (should fail). """ with pytest.raises(ValueError): - _ = AzureNetworkService( - config={ - "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", - }, + _ = AzureNetworkService(config={ + "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", }, - parent=azure_auth_service, - ) + }, parent=azure_auth_service) with pytest.raises(ValueError): - _ = AzureNetworkService( - config={ - # "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", - }, + _ = AzureNetworkService(config={ + # "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", }, - parent=azure_auth_service, - ) + }, parent=azure_auth_service) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 0fd94cf821..1d84d73cab 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -19,41 +19,27 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), - [ + ("total_retries", "operation_status"), [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ], -) + ]) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_host_deployment_retry( - mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService, -) -> None: +def test_wait_host_deployment_retry(mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService) -> None: """ Test retries of the host deployment operation. 
""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ - make_httplib_json_response( - 200, {"properties": {"provisioningState": "Running"}} - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - make_httplib_json_response( - 200, {"properties": {"provisioningState": "Running"}} - ), - make_httplib_json_response( - 200, {"properties": {"provisioningState": "Succeeded"}} - ), + make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), + make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] (status, _) = azure_vm_service.wait_host_deployment( @@ -64,14 +50,11 @@ def test_wait_host_deployment_retry( "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True, - ) + is_setup=True) assert status == operation_status -def test_azure_vm_service_recursive_template_params( - azure_auth_service: AzureAuthService, -) -> None: +def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None: """ Test expanding template params recursively. """ @@ -92,14 +75,8 @@ def test_azure_vm_service_recursive_template_params( } azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) assert azure_vm_service.deploy_params["location"] == global_config["location"] - assert ( - azure_vm_service.deploy_params["vmMeta"] - == f'{global_config["vmName"]}-{global_config["location"]}' - ) - assert ( - azure_vm_service.deploy_params["vmNsg"] - == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' - ) + assert azure_vm_service.deploy_params["vmMeta"] == f'{global_config["vmName"]}-{global_config["location"]}' + assert azure_vm_service.deploy_params["vmNsg"] == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None: @@ -121,17 +98,14 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N } with pytest.raises(ValueError): config_with_custom_data = deepcopy(config) - config_with_custom_data["deploymentTemplateParameters"]["customData"] = "DUMMY_CUSTOM_DATA" # type: ignore[index] - AzureVMService( - config_with_custom_data, global_config, parent=azure_auth_service - ) + config_with_custom_data['deploymentTemplateParameters']['customData'] = "DUMMY_CUSTOM_DATA" # type: ignore[index] + AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service) azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) - assert azure_vm_service.deploy_params["customData"] + assert azure_vm_service.deploy_params['customData'] @pytest.mark.parametrize( - ("operation_name", "accepts_params"), - [ + ("operation_name", "accepts_params"), [ ("start_host", True), ("stop_host", True), ("shutdown", True), @@ -139,27 +113,22 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N ("deallocate_host", True), 
("restart_host", True), ("reboot", True), - ], -) + ]) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), - [ + ("http_status_code", "operation_status"), [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ], -) + ]) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_vm_operation_status( - mock_requests: MagicMock, - azure_vm_service: AzureVMService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status, -) -> None: +def test_vm_operation_status(mock_requests: MagicMock, + azure_vm_service: AzureVMService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status) -> None: """ Test VM operation status. """ @@ -176,16 +145,12 @@ def test_vm_operation_status( @pytest.mark.parametrize( - ("operation_name", "accepts_params"), - [ + ("operation_name", "accepts_params"), [ ("provision_host", True), - ], -) -def test_vm_operation_invalid( - azure_vm_service_remote_exec_only: AzureVMService, - operation_name: str, - accepts_params: bool, -) -> None: + ]) +def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, + operation_name: str, + accepts_params: bool) -> None: """ Test VM operation status for an incomplete service config. """ @@ -196,9 +161,8 @@ def test_vm_operation_invalid( @patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep") @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_ready( - mock_session: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService -) -> None: +def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, + azure_vm_service: AzureVMService) -> None: """ Test waiting for the completion of the remote VM operation. """ @@ -219,15 +183,14 @@ def test_wait_vm_operation_ready( status, _ = azure_vm_service.wait_host_operation(params) - assert (async_url,) == mock_session.return_value.get.call_args[0] - assert (retry_after,) == mock_sleep.call_args[0] + assert (async_url, ) == mock_session.return_value.get.call_args[0] + assert (retry_after, ) == mock_sleep.call_args[0] assert status.is_succeeded() @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_timeout( - mock_session: MagicMock, azure_vm_service: AzureVMService -) -> None: +def test_wait_vm_operation_timeout(mock_session: MagicMock, + azure_vm_service: AzureVMService) -> None: """ Test the time out of the remote VM operation. 
""" @@ -235,7 +198,7 @@ def test_wait_vm_operation_timeout( params = { "asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", - "pollInterval": 1, + "pollInterval": 1 } mock_status_response = MagicMock(status_code=200) @@ -249,20 +212,16 @@ def test_wait_vm_operation_timeout( @pytest.mark.parametrize( - ("total_retries", "operation_status"), - [ + ("total_retries", "operation_status"), [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ], -) + ]) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_vm_operation_retry( - mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService, -) -> None: +def test_wait_vm_operation_retry(mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService) -> None: """ Test the retries of the remote VM operation. """ @@ -270,12 +229,8 @@ def test_wait_vm_operation_retry( # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"status": "InProgress"}), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), make_httplib_json_response(200, {"status": "InProgress"}), make_httplib_json_response(200, {"status": "Succeeded"}), ] @@ -286,27 +241,20 @@ def test_wait_vm_operation_retry( "requestTotalRetries": total_retries, "asyncResultsUrl": "https://DUMMY_ASYNC_URL", "vmName": "test-vm", - } - ) + }) assert status == operation_status @pytest.mark.parametrize( - ("http_status_code", "operation_status"), - [ + ("http_status_code", "operation_status"), [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ], -) + ]) @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_status( - mock_requests: MagicMock, - azure_vm_service_remote_exec_only: AzureVMService, - http_status_code: int, - operation_status: Status, -) -> None: +def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService, + http_status_code: int, operation_status: Status) -> None: """ Test waiting for completion of the remote execution on Azure. 
""" @@ -314,24 +262,19 @@ def test_remote_exec_status( mock_response = MagicMock() mock_response.status_code = http_status_code - mock_response.json = MagicMock( - return_value={ - "fake response": "body as json to dict", - } - ) + mock_response.json = MagicMock(return_value={ + "fake response": "body as json to dict", + }) mock_requests.post.return_value = mock_response - status, _ = azure_vm_service_remote_exec_only.remote_exec( - script, config={"vmName": "test-vm"}, env_params={} - ) + status, _ = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={}) assert status == operation_status @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_headers_output( - mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService -) -> None: +def test_remote_exec_headers_output(mock_requests: MagicMock, + azure_vm_service_remote_exec_only: AzureVMService) -> None: """ Check if HTTP headers from the remote execution on Azure are correct. """ @@ -341,22 +284,18 @@ def test_remote_exec_headers_output( mock_response = MagicMock() mock_response.status_code = 202 - mock_response.headers = {"Azure-AsyncOperation": async_url_value} - mock_response.json = MagicMock( - return_value={ - "fake response": "body as json to dict", - } - ) + mock_response.headers = { + "Azure-AsyncOperation": async_url_value + } + mock_response.json = MagicMock(return_value={ + "fake response": "body as json to dict", + }) mock_requests.post.return_value = mock_response - _, cmd_output = azure_vm_service_remote_exec_only.remote_exec( - script, - config={"vmName": "test-vm"}, - env_params={ - "param_1": 123, - "param_2": "abc", - }, - ) + _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={ + "param_1": 123, + "param_2": "abc", + }) assert async_url_key in cmd_output assert cmd_output[async_url_key] == async_url_value @@ -366,14 +305,13 @@ def test_remote_exec_headers_output( "script": script, "parameters": [ {"name": "param_1", "value": 123}, - {"name": "param_2", "value": "abc"}, - ], + {"name": "param_2", "value": "abc"} + ] } @pytest.mark.parametrize( - ("operation_status", "wait_output", "results_output"), - [ + ("operation_status", "wait_output", "results_output"), [ ( Status.SUCCEEDED, { @@ -385,18 +323,13 @@ def test_remote_exec_headers_output( } } }, - {"stdout": "DUMMY_STDOUT_STDERR"}, + {"stdout": "DUMMY_STDOUT_STDERR"} ), (Status.PENDING, {}, {}), (Status.FAILED, {}, {}), - ], -) -def test_get_remote_exec_results( - azure_vm_service_remote_exec_only: AzureVMService, - operation_status: Status, - wait_output: dict, - results_output: dict, -) -> None: + ]) +def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status, + wait_output: dict, results_output: dict) -> None: """ Test getting the results of the remote execution on Azure. 
""" @@ -405,15 +338,9 @@ def test_get_remote_exec_results( mock_wait_host_operation = MagicMock() mock_wait_host_operation.return_value = (operation_status, wait_output) # azure_vm_service.wait_host_operation = mock_wait_host_operation - setattr( - azure_vm_service_remote_exec_only, - "wait_host_operation", - mock_wait_host_operation, - ) - - status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results( - params - ) + setattr(azure_vm_service_remote_exec_only, "wait_host_operation", mock_wait_host_operation) + + status, cmd_output = azure_vm_service_remote_exec_only.get_remote_exec_results(params) assert status == operation_status assert mock_wait_host_operation.call_args[0][0] == params diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 1e997fc795..2794bb01cf 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -30,16 +30,12 @@ def config_persistence_service() -> ConfigPersistenceService: @pytest.fixture -def azure_auth_service( - config_persistence_service: ConfigPersistenceService, - monkeypatch: pytest.MonkeyPatch, -) -> AzureAuthService: +def azure_auth_service(config_persistence_service: ConfigPersistenceService, + monkeypatch: pytest.MonkeyPatch) -> AzureAuthService: """ Creates a dummy AzureAuthService for tests that require it. """ - auth = AzureAuthService( - config={}, global_config={}, parent=config_persistence_service - ) + auth = AzureAuthService(config={}, global_config={}, parent=config_persistence_service) monkeypatch.setattr(auth, "get_access_token", lambda: "TEST_TOKEN") return auth @@ -49,23 +45,19 @@ def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkS """ Creates a dummy Azure VM service for tests that require it. """ - return AzureNetworkService( - config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", - }, - "pollInterval": 1, - "pollTimeout": 2, + return AzureNetworkService(config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", }, - global_config={ - "deploymentName": "TEST_DEPLOYMENT-VNET", - "vnetName": "test-vnet", # Should come from the upper-level config - }, - parent=azure_auth_service, - ) + "pollInterval": 1, + "pollTimeout": 2 + }, global_config={ + "deploymentName": "TEST_DEPLOYMENT-VNET", + "vnetName": "test-vnet", # Should come from the upper-level config + }, parent=azure_auth_service) @pytest.fixture @@ -73,60 +65,44 @@ def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService: """ Creates a dummy Azure VM service for tests that require it. 
""" - return AzureVMService( - config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", - }, - "pollInterval": 1, - "pollTimeout": 2, - }, - global_config={ - "deploymentName": "TEST_DEPLOYMENT-VM", - "vmName": "test-vm", # Should come from the upper-level config + return AzureVMService(config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", }, - parent=azure_auth_service, - ) + "pollInterval": 1, + "pollTimeout": 2 + }, global_config={ + "deploymentName": "TEST_DEPLOYMENT-VM", + "vmName": "test-vm", # Should come from the upper-level config + }, parent=azure_auth_service) @pytest.fixture -def azure_vm_service_remote_exec_only( - azure_auth_service: AzureAuthService, -) -> AzureVMService: +def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> AzureVMService: """ Creates a dummy Azure VM service with no deployment template. """ - return AzureVMService( - config={ - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "pollInterval": 1, - "pollTimeout": 2, - }, - global_config={ - "vmName": "test-vm", # Should come from the upper-level config - }, - parent=azure_auth_service, - ) + return AzureVMService(config={ + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "pollInterval": 1, + "pollTimeout": 2, + }, global_config={ + "vmName": "test-vm", # Should come from the upper-level config + }, parent=azure_auth_service) @pytest.fixture -def azure_fileshare( - config_persistence_service: ConfigPersistenceService, -) -> AzureFileShareService: +def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService: """ Creates a dummy AzureFileShareService for tests that require it. """ with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"): - return AzureFileShareService( - config={ - "storageAccountName": "TEST_ACCOUNT_NAME", - "storageFileShareName": "TEST_FS_NAME", - "storageAccountKey": "TEST_ACCOUNT_KEY", - }, - global_config={}, - parent=config_persistence_service, - ) + return AzureFileShareService(config={ + "storageAccountName": "TEST_ACCOUNT_NAME", + "storageFileShareName": "TEST_FS_NAME", + "storageAccountKey": "TEST_ACCOUNT_KEY" + }, global_config={}, parent=config_persistence_service) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index fb1c4ee39b..b9474f0709 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -20,24 +20,16 @@ class MockAuthService(Service, SupportsAuth): A collection Service functions for mocking authentication ops. 
""" - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.get_access_token, - self.get_auth_headers, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + self.get_access_token, + self.get_auth_headers, + ]) ) def get_access_token(self) -> str: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index 79f8c608c2..1a026966a8 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -21,30 +21,21 @@ class MockFileShareService(FileShareService, SupportsFileShareOps): A collection Service functions for mocking file share ops. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.upload, self.download]), + config, global_config, parent, + self.merge_methods(methods, [self.upload, self.download]) ) self._upload: List[Tuple[str, str]] = [] self._download: List[Tuple[str, str]] = [] - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: self._upload.append((local_path, remote_path)) - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: self._download.append((remote_path, local_path)) def get_upload(self) -> List[Tuple[str, str]]: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py index 6bf9fc8d05..e6169d9f93 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py @@ -20,13 +20,10 @@ class MockNetworkService(Service, SupportsNetworkProvisioning): Mock Network service for testing. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of mock network services proxy. 
@@ -41,19 +38,13 @@ def __init__( Parent service that can provide mixin functions. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - { - name: mock_operation - for name in ( - # SupportsNetworkProvisioning: - "provision_network", - "deprovision_network", - "wait_network_deployment", - ) - }, - ), + config, global_config, parent, + self.merge_methods(methods, { + name: mock_operation for name in ( + # SupportsNetworkProvisioning: + "provision_network", + "deprovision_network", + "wait_network_deployment", + ) + }) ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py index 38d759f53c..ee99251c64 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py @@ -18,13 +18,10 @@ class MockRemoteExecService(Service, SupportsRemoteExec): Mock remote script execution service. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of mock remote exec service. @@ -39,14 +36,9 @@ def __init__( Parent service that can provide mixin functions. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - { - "remote_exec": mock_operation, - "get_remote_exec_results": mock_operation, - }, - ), + config, global_config, parent, + self.merge_methods(methods, { + "remote_exec": mock_operation, + "get_remote_exec_results": mock_operation, + }) ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index 3ae13cf6a6..a44edaf080 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -20,13 +20,10 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, Supports Mock VM service for testing. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of mock VM services proxy. @@ -41,29 +38,23 @@ def __init__( Parent service that can provide mixin functions. 
""" super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - { - name: mock_operation - for name in ( - # SupportsHostProvisioning: - "wait_host_deployment", - "provision_host", - "deprovision_host", - "deallocate_host", - # SupportsHostOps: - "start_host", - "stop_host", - "restart_host", - "wait_host_operation", - # SupportsOsOps: - "shutdown", - "reboot", - "wait_os_operation", - ) - }, - ), + config, global_config, parent, + self.merge_methods(methods, { + name: mock_operation for name in ( + # SupportsHostProvisioning: + "wait_host_deployment", + "provision_host", + "deprovision_host", + "deallocate_host", + # SupportsHostOps: + "start_host", + "stop_host", + "restart_host", + "wait_host_operation", + # SupportsOsOps: + "shutdown", + "reboot", + "wait_os_operation", + ) + }) ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index c893adfd4a..e0060d8047 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -17,9 +17,9 @@ # The SSH test server port and name. # See Also: docker-compose.yml SSH_TEST_SERVER_PORT = 2254 -SSH_TEST_SERVER_NAME = "ssh-server" -ALT_TEST_SERVER_NAME = "alt-server" -REBOOT_TEST_SERVER_NAME = "reboot-server" +SSH_TEST_SERVER_NAME = 'ssh-server' +ALT_TEST_SERVER_NAME = 'alt-server' +REBOOT_TEST_SERVER_NAME = 'reboot-server' @dataclass @@ -42,12 +42,8 @@ def get_port(self, uncached: bool = False) -> int: Note: this value can change when the service restarts so we can't rely on the DockerServices. """ if self._port is None or uncached: - port_cmd = run( - f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", - shell=True, - check=True, - capture_output=True, - ) + port_cmd = run(f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", + shell=True, check=True, capture_output=True) self._port = int(port_cmd.stdout.decode().strip().split(":")[1]) return self._port @@ -72,9 +68,7 @@ def to_connect_params(self, uncached: bool = False) -> dict: } -def wait_docker_service_socket( - docker_services: DockerServices, hostname: str, port: int -) -> None: +def wait_docker_service_socket(docker_services: DockerServices, hostname: str, port: int) -> None: """Wait until a docker service is ready.""" docker_services.wait_until_responsive( check=lambda: check_socket(hostname, port), diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 8b28856396..6f05fe953b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -30,28 +30,26 @@ # pylint: disable=redefined-outer-name -HOST_DOCKER_NAME = "host.docker.internal" +HOST_DOCKER_NAME = 'host.docker.internal' @pytest.fixture(scope="session") def ssh_test_server_hostname() -> str: """Returns the local hostname to use to connect to the test ssh server.""" - if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): + if sys.platform != 'win32' and resolve_host_name(HOST_DOCKER_NAME): # On Linux, if we're running in a docker container, we can use the # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. 
return HOST_DOCKER_NAME # Docker (Desktop) for Windows (WSL2) uses a special networking magic # to refer to the host machine as `localhost` when exposing ports. # In all other cases, assume we're executing directly inside conda on the host. - return "localhost" + return 'localhost' @pytest.fixture(scope="session") -def ssh_test_server( - ssh_test_server_hostname: str, - docker_compose_project_name: str, - locked_docker_services: DockerServices, -) -> Generator[SshTestServerInfo, None, None]: +def ssh_test_server(ssh_test_server_hostname: str, + docker_compose_project_name: str, + locked_docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]: """ Fixture for getting the ssh test server services setup via docker-compose using pytest-docker. @@ -68,37 +66,23 @@ def ssh_test_server( compose_project_name=docker_compose_project_name, service_name=SSH_TEST_SERVER_NAME, hostname=ssh_test_server_hostname, - username="root", - id_rsa_path=id_rsa_file.name, - ) - wait_docker_service_socket( - locked_docker_services, - ssh_test_server_info.hostname, - ssh_test_server_info.get_port(), - ) + username='root', + id_rsa_path=id_rsa_file.name) + wait_docker_service_socket(locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port()) id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa" docker_cp_cmd = f"docker compose -p {docker_compose_project_name} cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}" - cmd = run( - docker_cp_cmd.split(), - check=True, - cwd=os.path.dirname(__file__), - capture_output=True, - text=True, - ) + cmd = run(docker_cp_cmd.split(), check=True, cwd=os.path.dirname(__file__), capture_output=True, text=True) if cmd.returncode != 0: - raise RuntimeError( - f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " - + f"[return={cmd.returncode}]: {str(cmd.stderr)}" - ) + raise RuntimeError(f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " + + f"[return={cmd.returncode}]: {str(cmd.stderr)}") os.chmod(id_rsa_file.name, 0o600) yield ssh_test_server_info # NamedTempFile deleted on context exit @pytest.fixture(scope="session") -def alt_test_server( - ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices -) -> SshTestServerInfo: +def alt_test_server(ssh_test_server: SshTestServerInfo, + locked_docker_services: DockerServices) -> SshTestServerInfo: """ Fixture for getting the second ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. @@ -111,20 +95,14 @@ def alt_test_server( service_name=ALT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path, - ) - wait_docker_service_socket( - locked_docker_services, - alt_test_server_info.hostname, - alt_test_server_info.get_port(), - ) + id_rsa_path=ssh_test_server.id_rsa_path) + wait_docker_service_socket(locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port()) return alt_test_server_info @pytest.fixture(scope="session") -def reboot_test_server( - ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices -) -> SshTestServerInfo: +def reboot_test_server(ssh_test_server: SshTestServerInfo, + locked_docker_services: DockerServices) -> SshTestServerInfo: """ Fixture for getting the third ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. 
@@ -137,13 +115,8 @@ def reboot_test_server( service_name=REBOOT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path, - ) - wait_docker_service_socket( - locked_docker_services, - reboot_test_server_info.hostname, - reboot_test_server_info.get_port(), - ) + id_rsa_path=ssh_test_server.id_rsa_path) + wait_docker_service_socket(locked_docker_services, reboot_test_server_info.hostname, reboot_test_server_info.get_port()) return reboot_test_server_info diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index e3b9c85746..f2bbbe4b8a 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -52,9 +52,8 @@ def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, @requires_docker -def test_ssh_fileshare_single_file( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService -) -> None: +def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService single file download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -67,7 +66,7 @@ def test_ssh_fileshare_single_file( lines = [line + "\n" for line in lines] # 1. Write a local file and upload it. - with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: temp_file.writelines(lines) temp_file.flush() temp_file.close() @@ -79,7 +78,7 @@ def test_ssh_fileshare_single_file( ) # 2. Download the remote file and compare the contents. - with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: temp_file.close() ssh_fileshare_service.download( params=config, @@ -87,15 +86,14 @@ def test_ssh_fileshare_single_file( local_path=temp_file.name, ) # Download will replace the inode at that name, so we need to reopen the file. - with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: + with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == lines @requires_docker -def test_ssh_fileshare_recursive( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService -) -> None: +def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService recursive download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -115,17 +113,14 @@ def test_ssh_fileshare_recursive( "bar", ], } - files_lines = { - path: [line + "\n" for line in lines] - for (path, lines) in files_lines.items() - } + files_lines = {path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()} with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2: # Setup the directory structure. 
- for file_path, lines in files_lines.items(): + for (file_path, lines) in files_lines.items(): path = Path(tempdir1, file_path) path.parent.mkdir(parents=True, exist_ok=True) - with open(path, mode="w+t", encoding="utf-8") as temp_file: + with open(path, mode='w+t', encoding='utf-8') as temp_file: temp_file.writelines(lines) temp_file.flush() assert os.path.getsize(path) > 0 @@ -148,22 +143,19 @@ def test_ssh_fileshare_recursive( # Compare both. # Note: remote dir name is appended to target. - assert are_dir_trees_equal( - tempdir1, path_join(tempdir2, basename(remote_file_path)) - ) + assert are_dir_trees_equal(tempdir1, path_join(tempdir2, basename(remote_file_path))) @requires_docker -def test_ssh_fileshare_download_file_dne( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService -) -> None: +def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService single file download that doesn't exist.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() canary_str = "canary" - with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: temp_file.writelines([canary_str]) temp_file.flush() temp_file.close() @@ -174,22 +166,20 @@ def test_ssh_fileshare_download_file_dne( remote_path="/tmp/file-dne.txt", local_path=temp_file.name, ) - with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: + with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == [canary_str] @requires_docker -def test_ssh_fileshare_upload_file_dne( - ssh_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - ssh_fileshare_service: SshFileShareService, -) -> None: +def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService single file upload that doesn't exist.""" with ssh_host_service, ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() - path = "/tmp/upload-file-src-dne.txt" + path = '/tmp/upload-file-src-dne.txt' with pytest.raises(OSError): ssh_fileshare_service.upload( params=config, diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 6cea52a102..4c8e5e0c66 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -27,11 +27,9 @@ @requires_docker -def test_ssh_service_remote_exec( - ssh_test_server: SshTestServerInfo, - alt_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, -) -> None: +def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, + alt_test_server: SshTestServerInfo, + ssh_host_service: SshHostService) -> None: """ Test the SshHostService remote_exec. 
@@ -44,11 +42,7 @@ def test_ssh_service_remote_exec( connection_id = SshClient.id_from_params(ssh_test_server.to_connect_params()) assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None - connection_client = ( - ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get( - connection_id - ) - ) + connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(connection_id) assert connection_client is None (status, results_info) = ssh_host_service.remote_exec( @@ -63,9 +57,7 @@ def test_ssh_service_remote_exec( assert results["stdout"].strip() == SSH_TEST_SERVER_NAME # Check that the client caching is behaving as expected. - connection, client = ( - ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] - ) + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] assert connection is not None assert connection._username == ssh_test_server.username assert connection._host == ssh_test_server.hostname @@ -99,15 +91,13 @@ def test_ssh_service_remote_exec( }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) - assert status.is_failed() # should retain exit code from "false" + assert status.is_failed() # should retain exit code from "false" stdout = str(results["stdout"]) assert stdout.splitlines() == [ "BAR=bar", "UNUSED=", ] - connection, client = ( - ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] - ) + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] assert connection._local_port == local_port # Close the connection (gracefully) @@ -124,7 +114,7 @@ def test_ssh_service_remote_exec( config=config, # Also test interacting with environment_variables. env_params={ - "FOO": "foo", + 'FOO': 'foo', }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) @@ -137,21 +127,17 @@ def test_ssh_service_remote_exec( "BAZ=", ] # Make sure it looks like we reconnected. - connection, client = ( - ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] - ) + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] assert connection._local_port != local_port # Make sure the cache is cleaned up on context exit. assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0 -def check_ssh_service_reboot( - docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - graceful: bool, -) -> None: +def check_ssh_service_reboot(docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + graceful: bool) -> None: """ Check the SshHostService reboot operation. """ @@ -160,14 +146,12 @@ def check_ssh_service_reboot( # Also, it may cause issues with other parallel unit tests, so we run it as # a part of the same unit test for now. with ssh_host_service: - reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config( - uncached=True - ) + reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True) (status, results_info) = ssh_host_service.remote_exec( script=[ 'echo "sleeping..."', - "sleep 30", - 'echo "should not reach this point"', + 'sleep 30', + 'echo "should not reach this point"' ], config=reboot_test_srv_ssh_svc_conf, env_params={}, @@ -177,14 +161,11 @@ def check_ssh_service_reboot( time.sleep(1) # Now try to restart the server. 
- (status, reboot_results_info) = ssh_host_service.reboot( - params=reboot_test_srv_ssh_svc_conf, force=not graceful - ) + (status, reboot_results_info) = ssh_host_service.reboot(params=reboot_test_srv_ssh_svc_conf, + force=not graceful) assert status.is_pending() - (status, reboot_results_info) = ssh_host_service.wait_os_operation( - reboot_results_info - ) + (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info) # NOTE: reboot/shutdown ops mostly return FAILED, even though the reboot succeeds. _LOG.debug("reboot status: %s: %s", status, reboot_results_info) @@ -202,34 +183,19 @@ def check_ssh_service_reboot( time.sleep(1) # try to reconnect and see if the port changed try: - run_res = run( - "docker ps | grep mlos_bench-test- | grep reboot", - shell=True, - capture_output=True, - check=False, - ) + run_res = run("docker ps | grep mlos_bench-test- | grep reboot", shell=True, capture_output=True, check=False) print(run_res.stdout.decode()) print(run_res.stderr.decode()) - reboot_test_srv_ssh_svc_conf_new = ( - reboot_test_server.to_ssh_service_config(uncached=True) - ) - if ( - reboot_test_srv_ssh_svc_conf_new["ssh_port"] - != reboot_test_srv_ssh_svc_conf["ssh_port"] - ): + reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(uncached=True) + if reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"]: break except CalledProcessError as ex: _LOG.info("Failed to check port for reboot test server: %s", ex) - assert ( - reboot_test_srv_ssh_svc_conf_new["ssh_port"] - != reboot_test_srv_ssh_svc_conf["ssh_port"] - ) + assert reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"] - wait_docker_service_socket( - docker_services, - reboot_test_server.hostname, - reboot_test_srv_ssh_svc_conf_new["ssh_port"], - ) + wait_docker_service_socket(docker_services, + reboot_test_server.hostname, + reboot_test_srv_ssh_svc_conf_new["ssh_port"]) (status, results_info) = ssh_host_service.remote_exec( script=["hostname"], @@ -242,18 +208,12 @@ def check_ssh_service_reboot( @requires_docker -def test_ssh_service_reboot( - locked_docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, -) -> None: +def test_ssh_service_reboot(locked_docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService) -> None: """ Test the SshHostService reboot operation. """ # Grouped together to avoid parallel runner interactions. 
- check_ssh_service_reboot( - locked_docker_services, reboot_test_server, ssh_host_service, graceful=True - ) - check_ssh_service_reboot( - locked_docker_services, reboot_test_server, ssh_host_service, graceful=False - ) + check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=True) + check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=False) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index b8e489b030..7bee929fea 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -35,9 +35,7 @@ # We replaced pytest-lazy-fixture with pytest-lazy-fixtures: # https://github.com/TvoroG/pytest-lazy-fixture/issues/65 if version("pytest-lazy-fixture"): - raise UserWarning( - "pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it." - ) + raise UserWarning("pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it.") except PackageNotFoundError: # OK: pytest-lazy-fixture not installed pass @@ -45,16 +43,12 @@ @requires_docker @requires_ssh -@pytest.mark.parametrize( - ["ssh_test_server_info", "server_name"], - [ - (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), - (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), - ], -) -def test_ssh_service_test_infra( - ssh_test_server_info: SshTestServerInfo, server_name: str -) -> None: +@pytest.mark.parametrize(["ssh_test_server_info", "server_name"], [ + (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), + (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), +]) +def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, + server_name: str) -> None: """Check for the pytest-docker ssh test infra.""" assert ssh_test_server_info.service_name == server_name @@ -63,18 +57,17 @@ def test_ssh_service_test_infra( local_port = ssh_test_server_info.get_port() assert check_socket(ip_addr, local_port) - ssh_cmd = ( - "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " - + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " + ssh_cmd = "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " \ + + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " \ + f"-p {local_port} {ssh_test_server_info.hostname} hostname" - ) - cmd = run(ssh_cmd.split(), capture_output=True, text=True, check=True) + cmd = run(ssh_cmd.split(), + capture_output=True, + text=True, + check=True) assert cmd.stdout.strip() == server_name -@pytest.mark.filterwarnings( - "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" -) +@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") def test_ssh_service_context_handler() -> None: """ Test the SSH service context manager handling. 
@@ -100,43 +93,31 @@ def test_ssh_service_context_handler() -> None: time.sleep(0.25) assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None - ssh_fileshare_service = SshFileShareService( - config={}, global_config={}, parent=None - ) + ssh_fileshare_service = SshFileShareService(config={}, global_config={}, parent=None) assert ssh_fileshare_service assert not ssh_fileshare_service._in_context with ssh_fileshare_service: assert ssh_fileshare_service._in_context assert ssh_host_service._in_context - assert ( - SshService._EVENT_LOOP_CONTEXT._event_loop_thread - is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread + assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread \ + is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread \ is ssh_fileshare_service._EVENT_LOOP_CONTEXT._event_loop_thread - ) - assert ( - SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE - is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ + is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ is ssh_fileshare_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE - ) assert not ssh_fileshare_service._in_context # And that instance should be unusable after we are outside the context. - with pytest.raises( - AssertionError - ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = ssh_fileshare_service._run_coroutine( - asyncio.sleep(0.1, result="foo") - ) - raise ValueError( - f"Future should not have been available to wait on {future.result()}" - ) + with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result='foo')) + raise ValueError(f"Future should not have been available to wait on {future.result()}") # The background thread should remain running since we have another context still open. assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None -if __name__ == "__main__": +if __name__ == '__main__': # For debugging in Windows which has issues with pytest detection in vscode. 
pytest.main(["-n1", "--dist=no", "-k", "test_ssh_service_background_thread"]) diff --git a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py index 31daec07c3..463879634f 100644 --- a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py +++ b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py @@ -40,10 +40,7 @@ def test_service_method_register_without_constructor() -> None: # somehow having it in a different scope makes a difference if isinstance(mixin_service, SupportsSomeMethod): assert mixin_service.some_method() == f"{some_base_service}: base.some_method" - assert ( - mixin_service.some_other_method() - == f"{some_base_service}: base.some_other_method" - ) + assert mixin_service.some_other_method() == f"{some_base_service}: base.some_other_method" # register the child service mixin_service.register(some_child_service.export()) @@ -51,9 +48,6 @@ def test_service_method_register_without_constructor() -> None: assert mixin_service._services == {some_child_service} # check that the inheritance works as expected assert mixin_service.some_method() == f"{some_child_service}: child.some_method" - assert ( - mixin_service.some_other_method() - == f"{some_child_service}: base.some_other_method" - ) + assert mixin_service.some_other_method() == f"{some_child_service}: base.some_other_method" else: assert False diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index 7b859e79ba..2c16df65c4 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -18,12 +18,8 @@ exp_no_tunables_storage = sql_storage_fixtures.exp_no_tunables_storage mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage exp_storage_with_trials = sql_storage_fixtures.exp_storage_with_trials -exp_no_tunables_storage_with_trials = ( - sql_storage_fixtures.exp_no_tunables_storage_with_trials -) -mixed_numerics_exp_storage_with_trials = ( - sql_storage_fixtures.mixed_numerics_exp_storage_with_trials -) +exp_no_tunables_storage_with_trials = sql_storage_fixtures.exp_no_tunables_storage_with_trials +mixed_numerics_exp_storage_with_trials = sql_storage_fixtures.mixed_numerics_exp_storage_with_trials exp_data = sql_storage_fixtures.exp_data exp_no_tunables_data = sql_storage_fixtures.exp_no_tunables_data mixed_numerics_exp_data = sql_storage_fixtures.mixed_numerics_exp_data diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index 852155a8c6..8159043be1 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -22,32 +22,23 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) assert exp.objectives == exp_storage.opt_targets -def test_exp_data_root_env_config( - exp_storage: Storage.Experiment, exp_data: ExperimentData -) -> None: +def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None: """Tests the root_env_config property of ExperimentData""" # pylint: disable=protected-access - assert exp_data.root_env_config == ( - exp_storage._root_env_config, - exp_storage._git_repo, - exp_storage._git_commit, - ) + assert exp_data.root_env_config == (exp_storage._root_env_config, exp_storage._git_repo, exp_storage._git_commit) -def test_exp_trial_data_objectives( - 
storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups -) -> None: +def test_exp_trial_data_objectives(storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Start a new trial and check the storage for the trial data. """ - trial_opt_new = exp_storage.new_trial( - tunable_groups, - config={ - "opt_target": "some-other-target", - "opt_direction": "max", - }, - ) + trial_opt_new = exp_storage.new_trial(tunable_groups, config={ + "opt_target": "some-other-target", + "opt_direction": "max", + }) assert trial_opt_new.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_new.trial_id, @@ -55,13 +46,10 @@ def test_exp_trial_data_objectives( "opt_direction": "max", } - trial_opt_old = exp_storage.new_trial( - tunable_groups, - config={ - "opt_target": "back-compat", - # "opt_direction": "max", # missing - }, - ) + trial_opt_old = exp_storage.new_trial(tunable_groups, config={ + "opt_target": "back-compat", + # "opt_direction": "max", # missing + }) assert trial_opt_old.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_old.trial_id, @@ -78,9 +66,7 @@ def test_exp_trial_data_objectives( } -def test_exp_data_results_df( - exp_data: ExperimentData, tunable_groups: TunableGroups -) -> None: +def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: """Tests the results_df property of ExperimentData""" results_df = exp_data.results_df expected_trials_count = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT @@ -88,20 +74,12 @@ def test_exp_data_results_df( assert len(results_df["tunable_config_id"].unique()) == CONFIG_COUNT assert len(results_df["trial_id"].unique()) == expected_trials_count obj_target = next(iter(exp_data.objectives)) - assert ( - len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) - == expected_trials_count - ) + assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count (tunable, _covariant_group) = next(iter(tunable_groups)) - assert ( - len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) - == expected_trials_count - ) + assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_trials_count -def test_exp_data_tunable_config_trial_group_id_in_results_df( - exp_data: ExperimentData, -) -> None: +def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None: """ Tests the tunable_config_trial_group_id property of ExperimentData.results_df @@ -136,21 +114,15 @@ def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None: This tests bulk loading of the tunable_config_trial_groups. """ # Should be keyed by config_id. - assert list(exp_data.tunable_config_trial_groups.keys()) == list( - range(1, CONFIG_COUNT + 1) - ) + assert list(exp_data.tunable_config_trial_groups.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. - assert [ - config_trial_group.tunable_config_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [config_trial_group.tunable_config_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT + 1)) # And the tunable_config_trial_group_id should also match the minimum trial_id. 
- assert [ - config_trial_group.tunable_config_trial_group_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list( - range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT) - ) + assert [config_trial_group.tunable_config_trial_group_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: @@ -158,9 +130,9 @@ def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_configs.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. - assert [ - config.tunable_config_id for config in exp_data.tunable_configs.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [config.tunable_config_id + for config in exp_data.tunable_configs.values() + ] == list(range(1, CONFIG_COUNT + 1)) def test_exp_data_default_config_id(exp_data: ExperimentData) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index 91920190de..d0a5edc694 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -37,11 +37,9 @@ def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo], -) -> None: +def test_exp_trial_pending(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start a trial and check that it is pending. """ @@ -52,16 +50,14 @@ def test_exp_trial_pending( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_many( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo], -) -> None: +def test_exp_trial_pending_many(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start THREE trials and check that both are pending. """ - config1 = tunable_groups.copy().assign({"idle": "mwait"}) - config2 = tunable_groups.copy().assign({"idle": "noidle"}) + config1 = tunable_groups.copy().assign({'idle': 'mwait'}) + config2 = tunable_groups.copy().assign({'idle': 'noidle'}) trial_ids = { exp_storage.new_trial(config1).trial_id, exp_storage.new_trial(config2).trial_id, @@ -76,11 +72,9 @@ def test_exp_trial_pending_many( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_fail( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo], -) -> None: +def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start a trial, fail it, and and check that it is NOT pending. """ @@ -91,11 +85,9 @@ def test_exp_trial_pending_fail( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_success( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo], -) -> None: +def test_exp_trial_success(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start a trial, finish it successfully, and and check that it is NOT pending. 
""" @@ -106,39 +98,31 @@ def test_exp_trial_success( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_categ( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo], -) -> None: +def test_exp_trial_update_categ(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Update the trial with multiple metrics, some of which are categorical. """ trial = exp_storage.new_trial(tunable_groups) - trial.update( - Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"} - ) + trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"}) assert exp_storage.load() == ( [trial.trial_id], - [ - { - "idle": "halt", - "kernel_sched_latency_ns": "2000000", - "kernel_sched_migration_cost_ns": "-1", - "vmSize": "Standard_B4ms", - } - ], + [{ + 'idle': 'halt', + 'kernel_sched_latency_ns': '2000000', + 'kernel_sched_migration_cost_ns': '-1', + 'vmSize': 'Standard_B4ms' + }], [{"score": "99.9", "benchmark": "test"}], - [Status.SUCCEEDED], + [Status.SUCCEEDED] ) @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_twice( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo], -) -> None: +def test_exp_trial_update_twice(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Update the trial status twice and receive an error. """ @@ -149,11 +133,9 @@ def test_exp_trial_update_twice( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_3( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo], -) -> None: +def test_exp_trial_pending_3(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start THREE trials, let one succeed, another one fail and keep one not updated. Check that one is still pending another one can be loaded into the optimizer. diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 5d56a3e195..7e346a5ccc 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -36,7 +36,7 @@ def storage() -> SqlStorage: "drivername": "sqlite", "database": ":memory:", # "database": "mlos_bench.pytest.db", - }, + } ) @@ -106,9 +106,7 @@ def mixed_numerics_exp_storage( assert not exp._in_context -def _dummy_run_exp( - exp: SqlStorage.Experiment, tunable_name: Optional[str] -) -> SqlStorage.Experiment: +def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> SqlStorage.Experiment: """ Generates data by doing a simulated run of the given experiment. """ @@ -121,68 +119,47 @@ def _dummy_run_exp( (tunable_min, tunable_max) = tunable.range tunable_range = tunable_max - tunable_min rand_seed(SEED) - opt = MockOptimizer( - tunables=exp.tunables, - config={ - "seed": SEED, - # This should be the default, so we leave it omitted for now to test the default. - # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) - # "start_with_defaults": True, - }, - ) + opt = MockOptimizer(tunables=exp.tunables, config={ + "seed": SEED, + # This should be the default, so we leave it omitted for now to test the default. 
+        # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params)
+        # "start_with_defaults": True,
+    })
     assert opt.start_with_defaults
     for config_i in range(CONFIG_COUNT):
         tunables = opt.suggest()
         for repeat_j in range(CONFIG_TRIAL_REPEAT_COUNT):
-            trial = exp.new_trial(
-                tunables=tunables.copy(),
-                config={
-                    "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1,
-                    **{
-                        f"opt_{key}_{i}": val
-                        for (i, opt_target) in enumerate(exp.opt_targets.items())
-                        for (key, val) in zip(["target", "direction"], opt_target)
-                    },
-                },
-            )
+            trial = exp.new_trial(tunables=tunables.copy(), config={
+                "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1,
+                **{
+                    f"opt_{key}_{i}": val
+                    for (i, opt_target) in enumerate(exp.opt_targets.items())
+                    for (key, val) in zip(["target", "direction"], opt_target)
+                }
+            })
             if exp.tunables:
                 assert trial.tunable_config_id == config_i + 1
             else:
                 assert trial.tunable_config_id == 1
             if tunable_name:
-                tunable_value = float(
-                    tunables.get_tunable(tunable_name)[0].numerical_value
-                )
-                tunable_value_norm = (
-                    base_score * (tunable_value - tunable_min) / tunable_range
-                )
+                tunable_value = float(tunables.get_tunable(tunable_name)[0].numerical_value)
+                tunable_value_norm = base_score * (tunable_value - tunable_min) / tunable_range
             else:
                 tunable_value_norm = 0
             timestamp = datetime.now(UTC)
-            trial.update_telemetry(
-                status=Status.RUNNING,
-                timestamp=timestamp,
-                metrics=[
-                    (timestamp, "some-metric", tunable_value_norm + random() / 100),
-                ],
-            )
-            trial.update(
-                Status.SUCCEEDED,
-                timestamp,
-                metrics={
-                    # Give some variance on the score.
-                    # And some influence from the tunable value.
-                    "score": tunable_value_norm
-                    + random() / 100
-                },
-            )
+            trial.update_telemetry(status=Status.RUNNING, timestamp=timestamp, metrics=[
+                (timestamp, "some-metric", tunable_value_norm + random() / 100),
+            ])
+            trial.update(Status.SUCCEEDED, timestamp, metrics={
+                # Give some variance on the score.
+                # And some influence from the tunable value.
+                "score": tunable_value_norm + random() / 100
+            })
     return exp


 @pytest.fixture
-def exp_storage_with_trials(
-    exp_storage: SqlStorage.Experiment,
-) -> SqlStorage.Experiment:
+def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment:
     """
     Test fixture for Experiment using in-memory SQLite3 storage.
     """
@@ -190,9 +167,7 @@ def exp_storage_with_trials(


 @pytest.fixture
-def exp_no_tunables_storage_with_trials(
-    exp_no_tunables_storage: SqlStorage.Experiment,
-) -> SqlStorage.Experiment:
+def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Experiment) -> SqlStorage.Experiment:
     """
     Test fixture for Experiment using in-memory SQLite3 storage.
     """
@@ -201,9 +176,7 @@ def exp_no_tunables_storage_with_trials(


 @pytest.fixture
-def mixed_numerics_exp_storage_with_trials(
-    mixed_numerics_exp_storage: SqlStorage.Experiment,
-) -> SqlStorage.Experiment:
+def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment:
     """
     Test fixture for Experiment using in-memory SQLite3 storage.
     """
@@ -212,9 +185,7 @@ def mixed_numerics_exp_storage_with_trials(


 @pytest.fixture
-def exp_data(
-    storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment
-) -> ExperimentData:
+def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData:
     """
     Test fixture for ExperimentData.
""" @@ -222,9 +193,7 @@ def exp_data( @pytest.fixture -def exp_no_tunables_data( - storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment -) -> ExperimentData: +def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: """ Test fixture for ExperimentData with no tunable configs. """ @@ -232,9 +201,7 @@ def exp_no_tunables_data( @pytest.fixture -def mixed_numerics_exp_data( - storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment -) -> ExperimentData: +def mixed_numerics_exp_data(storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: """ Test fixture for ExperimentData with mixed numerical tunable types. """ diff --git a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py index e8c4d38a9a..ba965ed3c6 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py @@ -13,9 +13,8 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_exp_trial_pending( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups -) -> None: +def test_exp_trial_pending(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Schedule a trial and check that it is pending and has the right configuration. """ @@ -32,14 +31,13 @@ def test_exp_trial_pending( } -def test_exp_trial_configs( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups -) -> None: +def test_exp_trial_configs(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Start multiple trials with two different configs and check that we store only two config objects in the DB. """ - config1 = tunable_groups.copy().assign({"idle": "mwait"}) + config1 = tunable_groups.copy().assign({'idle': 'mwait'}) trials1 = [ exp_storage.new_trial(config1), exp_storage.new_trial(config1), @@ -48,7 +46,7 @@ def test_exp_trial_configs( assert trials1[0].tunable_config_id == trials1[1].tunable_config_id assert trials1[0].tunable_config_id == trials1[2].tunable_config_id - config2 = tunable_groups.copy().assign({"idle": "halt"}) + config2 = tunable_groups.copy().assign({'idle': 'halt'}) trials2 = [ exp_storage.new_trial(config2), exp_storage.new_trial(config2), @@ -65,10 +63,7 @@ def test_exp_trial_configs( ] assert len(pending_ids) == 6 assert len(set(pending_ids)) == 2 - assert set(pending_ids) == { - trials1[0].tunable_config_id, - trials2[0].tunable_config_id, - } + assert set(pending_ids) == {trials1[0].tunable_config_id, trials2[0].tunable_config_id} def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index c56efa0031..04f4f18ae3 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -22,9 +22,8 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]: return set(t.trial_id for t in trials) -def test_schedule_trial( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups -) -> None: +def test_schedule_trial(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Schedule several trials for future execution and retrieve them later at certain timestamps. 
""" @@ -40,16 +39,13 @@ def test_schedule_trial( # Schedule 1 hour in the future: trial_1h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr, config) # Schedule 2 hours in the future: - trial_2h = exp_storage.new_trial( - tunable_groups, timestamp + timedelta_1hr * 2, config - ) + trial_2h = exp_storage.new_trial(tunable_groups, timestamp + timedelta_1hr * 2, config) # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1min, running=False) - ) + exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -57,8 +53,7 @@ def test_schedule_trial( # Get trials scheduled to run within the next 1 hour: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr, running=False) - ) + exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -67,8 +62,7 @@ def test_schedule_trial( # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) - ) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -90,8 +84,7 @@ def test_schedule_trial( # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) - ) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) assert pending_ids == { trial_1h.trial_id, trial_2h.trial_id, @@ -99,8 +92,7 @@ def test_schedule_trial( # Get trials scheduled to run OR running within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True) - ) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -109,15 +101,11 @@ def test_schedule_trial( } # Mark some trials completed after 2 minutes: - trial_now1.update( - Status.SUCCEEDED, timestamp + timedelta_1min * 2, metrics={"score": 1.0} - ) + trial_now1.update(Status.SUCCEEDED, timestamp + timedelta_1min * 2, metrics={"score": 1.0}) trial_now2.update(Status.FAILED, timestamp + timedelta_1min * 2) # Another one completes after 2 hours: - trial_1h.update( - Status.SUCCEEDED, timestamp + timedelta_1hr * 2, metrics={"score": 1.0} - ) + trial_1h.update(Status.SUCCEEDED, timestamp + timedelta_1hr * 2, metrics={"score": 1.0}) # Check that three trials have completed so far: (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load() @@ -126,9 +114,7 @@ def test_schedule_trial( assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED] # Get only trials completed after trial_now2: - (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load( - last_trial_id=trial_now2.trial_id - ) + (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(last_trial_id=trial_now2.trial_id) assert trial_ids == [trial_1h.trial_id] assert len(trial_configs) == len(trial_scores) == 1 assert trial_status == [Status.SUCCEEDED] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index e1f033fae9..855c6cd861 100644 --- 
a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -20,9 +20,7 @@ # pylint: disable=redefined-outer-name -def zoned_telemetry_data( - zone_info: Optional[tzinfo], -) -> List[Tuple[datetime, str, Any]]: +def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, str, Any]]: """ Mock telemetry data for the trial. @@ -33,21 +31,18 @@ def zoned_telemetry_data( """ timestamp1 = datetime.now(zone_info) timestamp2 = timestamp1 + timedelta(seconds=1) - return sorted( - [ - (timestamp1, "cpu_load", 10.1), - (timestamp1, "memory", 20), - (timestamp1, "setup", "prod"), - (timestamp2, "cpu_load", 30.1), - (timestamp2, "memory", 40), - (timestamp2, "setup", "prod"), - ] - ) + return sorted([ + (timestamp1, "cpu_load", 10.1), + (timestamp1, "memory", 20), + (timestamp1, "setup", "prod"), + (timestamp2, "cpu_load", 30.1), + (timestamp2, "memory", 40), + (timestamp2, "setup", "prod"), + ]) -def _telemetry_str( - data: List[Tuple[datetime, str, Any]] -) -> List[Tuple[datetime, str, Optional[str]]]: +def _telemetry_str(data: List[Tuple[datetime, str, Any]] + ) -> List[Tuple[datetime, str, Optional[str]]]: """ Convert telemetry values to strings. """ @@ -56,12 +51,10 @@ def _telemetry_str( @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry( - storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo], -) -> None: +def test_update_telemetry(storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo]) -> None: """ Make sure update_telemetry() and load_telemetry() methods work. """ @@ -69,9 +62,7 @@ def test_update_telemetry( trial = exp_storage.new_trial(tunable_groups) assert exp_storage.load_telemetry(trial.trial_id) == [] - trial.update_telemetry( - Status.RUNNING, datetime.now(origin_zone_info), telemetry_data - ) + trial.update_telemetry(Status.RUNNING, datetime.now(origin_zone_info), telemetry_data) assert exp_storage.load_telemetry(trial.trial_id) == _telemetry_str(telemetry_data) # Also check that the TrialData telemetry looks right. @@ -82,11 +73,9 @@ def test_update_telemetry( @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry_twice( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo], -) -> None: +def test_update_telemetry_twice(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo]) -> None: """ Make sure update_telemetry() call is idempotent. """ diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index a3333acd2b..3b57222822 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -10,9 +10,8 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_trial_data_tunable_config_data( - exp_data: ExperimentData, tunable_groups: TunableGroups -) -> None: +def test_trial_data_tunable_config_data(exp_data: ExperimentData, + tunable_groups: TunableGroups) -> None: """ Check expected return values for TunableConfigData. """ @@ -30,18 +29,16 @@ def test_trial_metadata(exp_data: ExperimentData) -> None: """ Check expected return values for TunableConfigData metadata. 
""" - assert exp_data.objectives == {"score": "min"} - for trial_id, trial in exp_data.trials.items(): + assert exp_data.objectives == {'score': 'min'} + for (trial_id, trial) in exp_data.trials.items(): assert trial.metadata_dict == { - "opt_target_0": "score", - "opt_direction_0": "min", - "trial_number": trial_id, + 'opt_target_0': 'score', + 'opt_direction_0': 'min', + 'trial_number': trial_id, } -def test_trial_data_no_tunables_config_data( - exp_no_tunables_data: ExperimentData, -) -> None: +def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData) -> None: """ Check expected return values for TunableConfigData. """ @@ -51,14 +48,13 @@ def test_trial_data_no_tunables_config_data( def test_mixed_numerics_exp_trial_data( - mixed_numerics_exp_data: ExperimentData, - mixed_numerics_tunable_groups: TunableGroups, -) -> None: + mixed_numerics_exp_data: ExperimentData, + mixed_numerics_tunable_groups: TunableGroups) -> None: """ Tests that data type conversions are retained when loading experiment data with mixed numeric tunable types. """ trial = next(iter(mixed_numerics_exp_data.trials.values())) config = trial.tunable_config.config_dict - for tunable, _group in mixed_numerics_tunable_groups: + for (tunable, _group) in mixed_numerics_tunable_groups: assert isinstance(config[tunable.name], tunable.dtype) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index 987b1a75b2..d08b26e92d 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -16,19 +16,10 @@ def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None: trial_id = 1 trial = exp_data.trials[trial_id] tunable_config_trial_group = trial.tunable_config_trial_group - assert ( - tunable_config_trial_group.experiment_id - == exp_data.experiment_id - == trial.experiment_id - ) + assert tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id assert tunable_config_trial_group.tunable_config_id == trial.tunable_config_id assert tunable_config_trial_group.tunable_config == trial.tunable_config - assert ( - tunable_config_trial_group - == next( - iter(tunable_config_trial_group.trials.values()) - ).tunable_config_trial_group - ) + assert tunable_config_trial_group == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None: @@ -58,9 +49,7 @@ def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) # And so on ... -def test_tunable_config_trial_group_results_df( - exp_data: ExperimentData, tunable_groups: TunableGroups -) -> None: +def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: """Tests the results_df property of the TunableConfigTrialGroup.""" tunable_config_id = 2 expected_group_id = 4 @@ -69,38 +58,15 @@ def test_tunable_config_trial_group_results_df( # We shouldn't have the results for the other configs, just this one. 
expected_count = CONFIG_TRIAL_REPEAT_COUNT assert len(results_df) == expected_count - assert ( - len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) - == expected_count - ) + assert len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count assert len(results_df[(results_df["tunable_config_id"] != tunable_config_id)]) == 0 - assert ( - len( - results_df[ - (results_df["tunable_config_trial_group_id"] == expected_group_id) - ] - ) - == expected_count - ) - assert ( - len( - results_df[ - (results_df["tunable_config_trial_group_id"] != expected_group_id) - ] - ) - == 0 - ) + assert len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) == expected_count + assert len(results_df[(results_df["tunable_config_trial_group_id"] != expected_group_id)]) == 0 assert len(results_df["trial_id"].unique()) == expected_count obj_target = next(iter(exp_data.objectives)) - assert ( - len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) - == expected_count - ) + assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_count (tunable, _covariant_group) = next(iter(tunable_groups)) - assert ( - len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) - == expected_count - ) + assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_count def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None: @@ -110,16 +76,8 @@ def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None: tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id] trials = tunable_config_trial_group.trials assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT - assert all( - trial.tunable_config_trial_group.tunable_config_trial_group_id - == expected_group_id - for trial in trials.values() - ) - assert all( - trial.tunable_config_id == tunable_config_id - for trial in tunable_config_trial_group.trials.values() - ) - assert ( - exp_data.trials[expected_group_id] - == tunable_config_trial_group.trials[expected_group_id] - ) + assert all(trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id + for trial in trials.values()) + assert all(trial.tunable_config_id == tunable_config_id + for trial in tunable_config_trial_group.trials.values()) + assert exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py index 2aba200955..fa947610da 100644 --- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py +++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py @@ -24,9 +24,7 @@ ] -@pytest.mark.skipif( - sys.platform == "win32", reason="TZ environment variable is a UNIXism" -) +@pytest.mark.skipif(sys.platform == 'win32', reason="TZ environment variable is a UNIXism") @pytest.mark.parametrize(("tz_name"), ZONE_NAMES) @pytest.mark.parametrize(("test_file"), TZ_TEST_FILES) def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: @@ -47,6 +45,4 @@ def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: if cmd.returncode != 0: print(cmd.stdout.decode()) print(cmd.stderr.decode()) - raise AssertionError( - f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'" - ) + raise AssertionError(f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'") diff --git 
a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py index 8329b51bd0..822547b1da 100644 --- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py +++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py @@ -119,26 +119,24 @@ def mixed_numerics_tunable_groups() -> TunableGroups: tunable_groups : TunableGroups A new TunableGroups object for testing. """ - tunables = TunableGroups( - { - "mix-numerics": { - "cost": 1, - "params": { - "int": { - "description": "An integer", - "type": "int", - "default": 0, - "range": [0, 100], - }, - "float": { - "description": "A float", - "type": "float", - "default": 0, - "range": [0, 1], - }, + tunables = TunableGroups({ + "mix-numerics": { + "cost": 1, + "params": { + "int": { + "description": "An integer", + "type": "int", + "default": 0, + "range": [0, 100], }, - }, - } - ) + "float": { + "description": "A float", + "type": "float", + "default": 0, + "range": [0, 1], + }, + } + }, + }) tunables.reset() return tunables diff --git a/mlos_bench/mlos_bench/tests/tunables/conftest.py b/mlos_bench/mlos_bench/tests/tunables/conftest.py index 878471b59e..95de20d9b8 100644 --- a/mlos_bench/mlos_bench/tests/tunables/conftest.py +++ b/mlos_bench/mlos_bench/tests/tunables/conftest.py @@ -25,15 +25,12 @@ def tunable_categorical() -> Tunable: tunable : Tunable An instance of a categorical Tunable. """ - return Tunable( - "vmSize", - { - "description": "Azure VM size", - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - }, - ) + return Tunable("vmSize", { + "description": "Azure VM size", + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] + }) @pytest.fixture @@ -46,16 +43,13 @@ def tunable_int() -> Tunable: tunable : Tunable An instance of an integer Tunable. """ - return Tunable( - "kernel_sched_migration_cost_ns", - { - "description": "Cost of migrating the thread to another core", - "type": "int", - "default": 40000, - "range": [0, 500000], - "special": [-1], # Special value outside of the range - }, - ) + return Tunable("kernel_sched_migration_cost_ns", { + "description": "Cost of migrating the thread to another core", + "type": "int", + "default": 40000, + "range": [0, 500000], + "special": [-1] # Special value outside of the range + }) @pytest.fixture @@ -68,12 +62,9 @@ def tunable_float() -> Tunable: tunable : Tunable An instance of a float Tunable. 
""" - return Tunable( - "chaos_monkey_prob", - { - "description": "Probability of spontaneous VM shutdown", - "type": "float", - "default": 0.01, - "range": [0, 1], - }, - ) + return Tunable("chaos_monkey_prob", { + "description": "Probability of spontaneous VM shutdown", + "type": "float", + "default": 0.01, + "range": [0, 1] + }) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py index e8b3e6b4cc..0e910f3761 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py @@ -38,7 +38,7 @@ def test_tunable_categorical_types() -> None: "values": ["a", "b", "c"], "default": "a", }, - }, + } } } tunable_groups = TunableGroups(tunable_params) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index b29c3a1b9e..58bb0368b1 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -14,7 +14,6 @@ # Note: these test do *not* check the ConfigSpace conversions for those same Tunables. # That is checked indirectly via grid_search_optimizer_test.py - def test_tunable_int_size_props() -> None: """Test tunable int size properties""" tunable = Tunable( @@ -23,8 +22,7 @@ def test_tunable_int_size_props() -> None: "type": "int", "range": [1, 5], "default": 3, - }, - ) + }) assert tunable.span == 4 assert tunable.cardinality == 5 expected = [1, 2, 3, 4, 5] @@ -40,8 +38,7 @@ def test_tunable_float_size_props() -> None: "type": "float", "range": [1.5, 5], "default": 3, - }, - ) + }) assert tunable.span == 3.5 assert tunable.cardinality == np.inf assert tunable.quantized_values is None @@ -56,8 +53,7 @@ def test_tunable_categorical_size_props() -> None: "type": "categorical", "values": ["a", "b", "c"], "default": "a", - }, - ) + }) with pytest.raises(AssertionError): _ = tunable.span assert tunable.cardinality == 3 @@ -74,9 +70,8 @@ def test_tunable_quantized_int_size_props() -> None: "type": "int", "range": [100, 1000], "default": 100, - "quantization": 100, - }, - ) + "quantization": 100 + }) assert tunable.span == 900 assert tunable.cardinality == 10 expected = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] @@ -88,8 +83,12 @@ def test_tunable_quantized_float_size_props() -> None: """Test quantized tunable float size properties""" tunable = Tunable( name="test", - config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1}, - ) + config={ + "type": "float", + "range": [0, 1], + "default": 0, + "quantization": .1 + }) assert tunable.span == 1 assert tunable.cardinality == 11 expected = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py index 407998b3a4..6a91b14016 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py @@ -28,7 +28,7 @@ def test_tunable_int_name_lt(tunable_int: Tunable) -> None: Tests that the __lt__ operator works as expected. 
""" tunable_int_2 = tunable_int.copy() - tunable_int_2._name = "aaa" # pylint: disable=protected-access + tunable_int_2._name = "aaa" # pylint: disable=protected-access assert tunable_int_2 < tunable_int @@ -38,8 +38,7 @@ def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None: """ tunable_categorical_2 = tunable_categorical.copy() new_value = [ - x - for x in tunable_categorical.categories + x for x in tunable_categorical.categories if x != tunable_categorical.category and x is not None ][0] assert tunable_categorical.category is not None @@ -60,7 +59,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - }, + } ) tunable_dog = Tunable( name="same-name", @@ -68,7 +67,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": [None, "doggo"], "default": None, - }, + } ) assert tunable_dog < tunable_cat @@ -83,7 +82,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - }, + } ) tunable_int = Tunable( name="same-name", @@ -91,7 +90,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "int", "range": [1, 3], "default": 2, - }, + } ) assert tunable_cat < tunable_int @@ -102,7 +101,7 @@ def test_tunable_lt_different_object(tunable_int: Tunable) -> None: """ assert (tunable_int < "foo") is False with pytest.raises(TypeError): - assert "foo" < tunable_int # type: ignore[operator] + assert "foo" < tunable_int # type: ignore[operator] def test_tunable_group_ne_object(tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py index d2ab29f27d..f2da3ba60e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py @@ -18,9 +18,7 @@ def test_tunable_name() -> None: """ with pytest.raises(ValueError): # ! 
characters are currently disallowed in tunable names - Tunable( - name="test!tunable", config={"type": "float", "range": [0, 1], "default": 0} - ) + Tunable(name='test!tunable', config={"type": "float", "range": [0, 1], "default": 0}) def test_categorical_required_params() -> None: @@ -36,7 +34,7 @@ def test_categorical_required_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_weights() -> None: @@ -52,7 +50,7 @@ def test_categorical_weights() -> None: } """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.weights == [25, 25, 50] @@ -70,7 +68,7 @@ def test_categorical_weights_wrong_count() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_weights_wrong_values() -> None: @@ -87,7 +85,7 @@ def test_categorical_weights_wrong_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_wrong_params() -> None: @@ -104,7 +102,7 @@ def test_categorical_wrong_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_disallow_special_values() -> None: @@ -121,7 +119,7 @@ def test_categorical_disallow_special_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_tunable_disallow_repeats() -> None: @@ -129,50 +127,37 @@ def test_categorical_tunable_disallow_repeats() -> None: Disallow duplicate values in categorical tunables. """ with pytest.raises(ValueError): - Tunable( - name="test", - config={ - "type": "categorical", - "values": ["foo", "bar", "foo"], - "default": "foo", - }, - ) + Tunable(name='test', config={ + "type": "categorical", + "values": ["foo", "bar", "foo"], + "default": "foo", + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) -def test_numerical_tunable_disallow_null_default( - tunable_type: TunableValueTypeName, -) -> None: +def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeName) -> None: """ Disallow null values as default for numerical tunables. """ with pytest.raises(ValueError): - Tunable( - name=f"test_{tunable_type}", - config={ - "type": tunable_type, - "range": [0, 10], - "default": None, - }, - ) + Tunable(name=f'test_{tunable_type}', config={ + "type": tunable_type, + "range": [0, 10], + "default": None, + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) -def test_numerical_tunable_disallow_out_of_range( - tunable_type: TunableValueTypeName, -) -> None: +def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeName) -> None: """ Disallow out of range values as default for numerical tunables. 
""" with pytest.raises(ValueError): - Tunable( - name=f"test_{tunable_type}", - config={ - "type": tunable_type, - "range": [0, 10], - "default": 11, - }, - ) + Tunable(name=f'test_{tunable_type}', config={ + "type": tunable_type, + "range": [0, 10], + "default": 11, + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -181,15 +166,12 @@ def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> N Disallow values param for numerical tunables. """ with pytest.raises(ValueError): - Tunable( - name=f"test_{tunable_type}", - config={ - "type": tunable_type, - "range": [0, 10], - "values": ["foo", "bar"], - "default": 0, - }, - ) + Tunable(name=f'test_{tunable_type}', config={ + "type": tunable_type, + "range": [0, 10], + "values": ["foo", "bar"], + "default": 0, + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -206,7 +188,7 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f"test_{tunable_type}", config=config) + Tunable(name=f'test_{tunable_type}', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -223,7 +205,7 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(AssertionError): - Tunable(name=f"test_{tunable_type}", config=config) + Tunable(name=f'test_{tunable_type}', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -240,7 +222,7 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f"test_{tunable_type}", config=config) + Tunable(name=f'test_{tunable_type}', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -259,7 +241,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.special == [0] assert tunable.weights == [0.1] assert tunable.range_weight == 0.9 @@ -279,7 +261,7 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.quantization == 10 assert not tunable.is_log @@ -298,7 +280,7 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.is_log @@ -317,7 +299,7 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -337,7 +319,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.special == [-1, 0] assert tunable.weights == [0, 10] # Zero weights are ok assert tunable.range_weight == 90 @@ -360,7 +342,7 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with 
pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -379,7 +361,7 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -398,7 +380,7 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -416,7 +398,7 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -436,7 +418,7 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -454,7 +436,7 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_bad_type() -> None: @@ -470,4 +452,4 @@ def test_bad_type() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test_bad_type", config=config) + Tunable(name='test_bad_type', config=config) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py index e8817319ab..deffcb6a46 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py @@ -17,15 +17,14 @@ def test_categorical_distribution() -> None: Try to instantiate a categorical tunable with distribution specified. """ with pytest.raises(ValueError): - Tunable( - name="test", - config={ - "type": "categorical", - "values": ["foo", "bar", "baz"], - "distribution": {"type": "uniform"}, - "default": "foo", + Tunable(name='test', config={ + "type": "categorical", + "values": ["foo", "bar", "baz"], + "distribution": { + "type": "uniform" }, - ) + "default": "foo" + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -33,15 +32,14 @@ def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> N """ Create a numeric Tunable with explicit uniform distribution. """ - tunable = Tunable( - name="test", - config={ - "type": tunable_type, - "range": [0, 10], - "distribution": {"type": "uniform"}, - "default": 0, + tunable = Tunable(name="test", config={ + "type": tunable_type, + "range": [0, 10], + "distribution": { + "type": "uniform" }, - ) + "default": 0 + }) assert tunable.is_numerical assert tunable.distribution == "uniform" assert not tunable.distribution_params @@ -52,15 +50,18 @@ def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> No """ Create a numeric Tunable with explicit Gaussian distribution specified. 
""" - tunable = Tunable( - name="test", - config={ - "type": tunable_type, - "range": [0, 10], - "distribution": {"type": "normal", "params": {"mu": 0, "sigma": 1.0}}, - "default": 0, + tunable = Tunable(name="test", config={ + "type": tunable_type, + "range": [0, 10], + "distribution": { + "type": "normal", + "params": { + "mu": 0, + "sigma": 1.0 + } }, - ) + "default": 0 + }) assert tunable.distribution == "normal" assert tunable.distribution_params == {"mu": 0, "sigma": 1.0} @@ -70,15 +71,18 @@ def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None """ Create a numeric Tunable with explicit Beta distribution specified. """ - tunable = Tunable( - name="test", - config={ - "type": tunable_type, - "range": [0, 10], - "distribution": {"type": "beta", "params": {"alpha": 2, "beta": 5}}, - "default": 0, + tunable = Tunable(name="test", config={ + "type": tunable_type, + "range": [0, 10], + "distribution": { + "type": "beta", + "params": { + "alpha": 2, + "beta": 5 + } }, - ) + "default": 0 + }) assert tunable.distribution == "beta" assert tunable.distribution_params == {"alpha": 2, "beta": 5} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py index eb73b34d12..c6fb5670f0 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py @@ -10,9 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_tunable_group_indexing( - tunable_groups: TunableGroups, tunable_categorical: Tunable -) -> None: +def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categorical: Tunable) -> None: """ Check that various types of indexing work for the tunable group. """ @@ -22,9 +20,7 @@ def test_tunable_group_indexing( # NOTE: we reassign the tunable_categorical here since they come from # different fixtures so are technically different objects. - (tunable_categorical, covariant_group) = tunable_groups.get_tunable( - tunable_categorical.name - ) + (tunable_categorical, covariant_group) = tunable_groups.get_tunable(tunable_categorical.name) assert tunable_groups.get_tunable(tunable_categorical)[0] == tunable_categorical assert tunable_categorical in covariant_group @@ -44,9 +40,7 @@ def test_tunable_group_indexing( assert covariant_group[tunable_categorical.name] == tunable_categorical.value # Check that we can assign a new value by index. - new_value = [ - x for x in tunable_categorical.categories if x != tunable_categorical.value - ][0] + new_value = [x for x in tunable_categorical.categories if x != tunable_categorical.value][0] tunable_groups[tunable_categorical] = new_value assert tunable_groups[tunable_categorical] == new_value assert tunable_groups[tunable_categorical.name] == new_value diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py index 186de4acfa..55a485e951 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py @@ -14,4 +14,4 @@ def test_tunable_group_subgroup(tunable_groups: TunableGroups) -> None: Check that the subgroup() method returns only a selection of tunable parameters. 
""" tunables = tunable_groups.subgroup(["provision"]) - assert tunables.get_param_values() == {"vmSize": "Standard_B4ms"} + assert tunables.get_param_values() == {'vmSize': 'Standard_B4ms'} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index 0dfbdd2acd..73e3a12caa 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -36,39 +36,37 @@ @pytest.mark.parametrize("param_type", ["int", "float"]) -@pytest.mark.parametrize( - "distr_name,distr_params", - [ - ("normal", {"mu": 0.0, "sigma": 1.0}), - ("beta", {"alpha": 2, "beta": 5}), - ("uniform", {}), - ], -) -def test_convert_numerical_distributions( - param_type: str, distr_name: DistributionName, distr_params: dict -) -> None: +@pytest.mark.parametrize("distr_name,distr_params", [ + ("normal", {"mu": 0.0, "sigma": 1.0}), + ("beta", {"alpha": 2, "beta": 5}), + ("uniform", {}), +]) +def test_convert_numerical_distributions(param_type: str, + distr_name: DistributionName, + distr_params: dict) -> None: """ Convert a numerical Tunable with explicit distribution to ConfigSpace. """ tunable_name = "x" - tunable_groups = TunableGroups( - { - "tunable_group": { - "cost": 1, - "params": { - tunable_name: { - "type": param_type, - "range": [0, 100], - "special": [-1, 0], - "special_weights": [0.1, 0.2], - "range_weight": 0.7, - "distribution": {"type": distr_name, "params": distr_params}, - "default": 0, - } - }, + tunable_groups = TunableGroups({ + "tunable_group": { + "cost": 1, + "params": { + tunable_name: { + "type": param_type, + "range": [0, 100], + "special": [-1, 0], + "special_weights": [0.1, 0.2], + "range_weight": 0.7, + "distribution": { + "type": distr_name, + "params": distr_params + }, + "default": 0 + } } } - ) + }) (tunable, _group) = tunable_groups.get_tunable(tunable_name) assert tunable.distribution == distr_name @@ -84,5 +82,5 @@ def test_convert_numerical_distributions( cs_param = space[tunable_name] assert isinstance(cs_param, _CS_HYPERPARAMETER[param_type, distr_name]) - for key, val in distr_params.items(): + for (key, val) in distr_params.items(): assert getattr(cs_param, key) == val diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index c92187a3e7..78e91fd25e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -38,23 +38,17 @@ def configuration_space() -> ConfigurationSpace: configuration_space : ConfigurationSpace A new ConfigurationSpace object for testing. 
""" - (kernel_sched_migration_cost_ns_special, kernel_sched_migration_cost_ns_type) = ( - special_param_names("kernel_sched_migration_cost_ns") - ) - - spaces = ConfigurationSpace( - space={ - "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - "idle": ["halt", "mwait", "noidle"], - "kernel_sched_migration_cost_ns": (0, 500000), - kernel_sched_migration_cost_ns_special: [-1, 0], - kernel_sched_migration_cost_ns_type: [ - TunableValueKind.SPECIAL, - TunableValueKind.RANGE, - ], - "kernel_sched_latency_ns": (0, 1000000000), - } - ) + (kernel_sched_migration_cost_ns_special, + kernel_sched_migration_cost_ns_type) = special_param_names("kernel_sched_migration_cost_ns") + + spaces = ConfigurationSpace(space={ + "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + "idle": ["halt", "mwait", "noidle"], + "kernel_sched_migration_cost_ns": (0, 500000), + kernel_sched_migration_cost_ns_special: [-1, 0], + kernel_sched_migration_cost_ns_type: [TunableValueKind.SPECIAL, TunableValueKind.RANGE], + "kernel_sched_latency_ns": (0, 1000000000), + }) # NOTE: FLAML requires distribution to be uniform spaces["vmSize"].default_value = "Standard_B4ms" @@ -66,27 +60,18 @@ def configuration_space() -> ConfigurationSpace: spaces[kernel_sched_migration_cost_ns_type].probabilities = (0.5, 0.5) spaces["kernel_sched_latency_ns"].default_value = 2000000 - spaces.add_condition( - EqualsCondition( - spaces[kernel_sched_migration_cost_ns_special], - spaces[kernel_sched_migration_cost_ns_type], - TunableValueKind.SPECIAL, - ) - ) - spaces.add_condition( - EqualsCondition( - spaces["kernel_sched_migration_cost_ns"], - spaces[kernel_sched_migration_cost_ns_type], - TunableValueKind.RANGE, - ) - ) + spaces.add_condition(EqualsCondition( + spaces[kernel_sched_migration_cost_ns_special], + spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.SPECIAL)) + spaces.add_condition(EqualsCondition( + spaces["kernel_sched_migration_cost_ns"], + spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.RANGE)) return spaces def _cmp_tunable_hyperparameter_categorical( - tunable: Tunable, space: ConfigurationSpace -) -> None: + tunable: Tunable, space: ConfigurationSpace) -> None: """ Check if categorical Tunable and ConfigSpace Hyperparameter actually match. """ @@ -97,8 +82,7 @@ def _cmp_tunable_hyperparameter_categorical( def _cmp_tunable_hyperparameter_numerical( - tunable: Tunable, space: ConfigurationSpace -) -> None: + tunable: Tunable, space: ConfigurationSpace) -> None: """ Check if integer Tunable and ConfigSpace Hyperparameter actually match. """ @@ -146,13 +130,12 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> Non Make sure that the corresponding Tunable and Hyperparameter objects match. """ space = tunable_groups_to_configspace(tunable_groups) - for tunable, _group in tunable_groups: + for (tunable, _group) in tunable_groups: _CMP_FUNC[tunable.type](tunable, space) def test_tunable_groups_to_configspace( - tunable_groups: TunableGroups, configuration_space: ConfigurationSpace -) -> None: + tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None: """ Check the conversion of the entire TunableGroups collection to a single ConfigurationSpace object. 
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index 2f7790602f..cbccd6bfe1 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -19,14 +19,12 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: that don't exist in the TunableGroups object. """ with pytest.raises(KeyError): - tunable_groups.assign( - { - "vmSize": "Standard_B2ms", - "idle": "mwait", - "UnknownParam_1": 1, - "UnknownParam_2": "invalid-value", - } - ) + tunable_groups.assign({ + "vmSize": "Standard_B2ms", + "idle": "mwait", + "UnknownParam_1": 1, + "UnknownParam_2": "invalid-value" + }) def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None: @@ -108,7 +106,7 @@ def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None: Check str to int coercion. """ tunable_int.value = "10" - assert tunable_int.value == 10 # type: ignore[comparison-overlap] + assert tunable_int.value == 10 # type: ignore[comparison-overlap] assert not tunable_int.is_special @@ -117,7 +115,7 @@ def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None: Check str to float coercion. """ tunable_float.value = "0.5" - assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] + assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] assert not tunable_float.is_special @@ -151,12 +149,12 @@ def test_tunable_assign_null_to_categorical() -> None: } """ config = json.loads(json_config) - categorical_tunable = Tunable(name="categorical_test", config=config) + categorical_tunable = Tunable(name='categorical_test', config=config) assert categorical_tunable assert categorical_tunable.category == "foo" categorical_tunable.value = None assert categorical_tunable.value is None - assert categorical_tunable.value != "None" + assert categorical_tunable.value != 'None' assert categorical_tunable.category is None @@ -167,7 +165,7 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_int.value = None with pytest.raises((TypeError, AssertionError)): - tunable_int.numerical_value = None # type: ignore[assignment] + tunable_int.numerical_value = None # type: ignore[assignment] def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: @@ -177,7 +175,7 @@ def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_float.value = None with pytest.raises((TypeError, AssertionError)): - tunable_float.numerical_value = None # type: ignore[assignment] + tunable_float.numerical_value = None # type: ignore[assignment] def test_tunable_assign_special(tunable_int: Tunable) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py index cb41f7f7d8..672b16ab73 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py @@ -17,44 +17,42 @@ def test_tunable_groups_str(tunable_groups: TunableGroups) -> None: tunables within each covariant group. 
""" # Same as `tunable_groups` (defined in the `conftest.py` file), but in different order: - tunables_other = TunableGroups( - { - "kernel": { - "cost": 1, - "params": { - "kernel_sched_latency_ns": { - "type": "int", - "default": 2000000, - "range": [0, 1000000000], - }, - "kernel_sched_migration_cost_ns": { - "type": "int", - "default": -1, - "range": [0, 500000], - "special": [-1], - }, + tunables_other = TunableGroups({ + "kernel": { + "cost": 1, + "params": { + "kernel_sched_latency_ns": { + "type": "int", + "default": 2000000, + "range": [0, 1000000000] }, - }, - "boot": { - "cost": 300, - "params": { - "idle": { - "type": "categorical", - "default": "halt", - "values": ["halt", "mwait", "noidle"], - } - }, - }, - "provision": { - "cost": 1000, - "params": { - "vmSize": { - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - } - }, - }, - } - ) + "kernel_sched_migration_cost_ns": { + "type": "int", + "default": -1, + "range": [0, 500000], + "special": [-1] + } + } + }, + "boot": { + "cost": 300, + "params": { + "idle": { + "type": "categorical", + "default": "halt", + "values": ["halt", "mwait", "noidle"] + } + } + }, + "provision": { + "cost": 1000, + "params": { + "vmSize": { + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] + } + } + }, + }) assert str(tunable_groups) == str(tunables_other) diff --git a/mlos_bench/mlos_bench/tunables/__init__.py b/mlos_bench/mlos_bench/tunables/__init__.py index 3433f4a735..4191f37d89 100644 --- a/mlos_bench/mlos_bench/tunables/__init__.py +++ b/mlos_bench/mlos_bench/tunables/__init__.py @@ -10,7 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups __all__ = [ - "Tunable", - "TunableValue", - "TunableGroups", + 'Tunable', + 'TunableValue', + 'TunableGroups', ] diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index 797510a087..fee4fd5841 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -93,12 +93,10 @@ def __eq__(self, other: object) -> bool: return False # TODO: May need to provide logic to relax the equality check on the # tunables (e.g. "compatible" vs. "equal"). 
- return ( - self._name == other._name - and self._cost == other._cost - and self._is_updated == other._is_updated - and self._tunables == other._tunables - ) + return (self._name == other._name and + self._cost == other._cost and + self._is_updated == other._is_updated and + self._tunables == other._tunables) def equals_defaults(self, other: "CovariantTunableGroup") -> bool: """ @@ -236,11 +234,7 @@ def __contains__(self, tunable: Union[str, Tunable]) -> bool: def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: return self.get_tunable(tunable).value - def __setitem__( - self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] - ) -> TunableValue: - value: TunableValue = ( - tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value - ) + def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: + value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value self._is_updated |= self.get_tunable(tunable).update(value) return value diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 1886d09597..1ebd70dfa4 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -107,9 +107,7 @@ def __init__(self, name: str, config: TunableDict): config : dict Python dict that represents a Tunable (e.g., deserialized from JSON) """ - if ( - not isinstance(name, str) or "!" in name - ): # TODO: Use a regex here and in JSON schema + if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema raise ValueError(f"Invalid name of the tunable: {name}") self._name = name self._type: TunableValueTypeName = config["type"] # required @@ -117,9 +115,7 @@ def __init__(self, name: str, config: TunableDict): raise ValueError(f"Invalid parameter type: {self._type}") self._description = config.get("description") self._default = config["default"] - self._default = ( - self.dtype(self._default) if self._default is not None else self._default - ) + self._default = self.dtype(self._default) if self._default is not None else self._default self._values = config.get("values") if self._values: self._values = [str(v) if v is not None else v for v in self._values] @@ -158,9 +154,7 @@ def _sanity_check(self) -> None: else: raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}") if not self.is_valid(self.default): - raise ValueError( - f"Invalid default value for tunable {self}: {self.default}" - ) + raise ValueError(f"Invalid default value for tunable {self}: {self.default}") def _sanity_check_categorical(self) -> None: """ @@ -169,17 +163,11 @@ def _sanity_check_categorical(self) -> None: # pylint: disable=too-complex assert self.is_categorical if not (self._values and isinstance(self._values, collections.abc.Iterable)): - raise ValueError( - f"Must specify values for the categorical type tunable {self}" - ) + raise ValueError(f"Must specify values for the categorical type tunable {self}") if self._range is not None: - raise ValueError( - f"Range must be None for the categorical type tunable {self}" - ) + raise ValueError(f"Range must be None for the categorical type tunable {self}") if len(set(self._values)) != len(self._values): - raise ValueError( - f"Values must be unique for the categorical type tunable {self}" - ) + raise ValueError(f"Values must be unique for the categorical type tunable {self}") if self._special: raise 
ValueError(f"Categorical tunable cannot have special values: {self}") if self._range_weight is not None: @@ -187,13 +175,9 @@ def _sanity_check_categorical(self) -> None: if self._log is not None: raise ValueError(f"Categorical tunable cannot have log parameter: {self}") if self._quantization is not None: - raise ValueError( - f"Categorical tunable cannot have quantization parameter: {self}" - ) + raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}") if self._distribution is not None: - raise ValueError( - f"Categorical parameters do not support `distribution`: {self}" - ) + raise ValueError(f"Categorical parameters do not support `distribution`: {self}") if self._weights: if len(self._weights) != len(self._values): raise ValueError(f"Must specify weights for all values: {self}") @@ -207,31 +191,21 @@ def _sanity_check_numerical(self) -> None: # pylint: disable=too-complex,too-many-branches assert self.is_numerical if self._values is not None: - raise ValueError( - f"Values must be None for the numerical type tunable {self}" - ) + raise ValueError(f"Values must be None for the numerical type tunable {self}") if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]: raise ValueError(f"Invalid range for tunable {self}: {self._range}") if self._quantization is not None: if self.dtype == int: if not isinstance(self._quantization, int): - raise ValueError( - f"Quantization of a int param should be an int: {self}" - ) + raise ValueError(f"Quantization of a int param should be an int: {self}") if self._quantization <= 1: raise ValueError(f"Number of quantization points is <= 1: {self}") if self.dtype == float: if not isinstance(self._quantization, (float, int)): - raise ValueError( - f"Quantization of a float param should be a float or int: {self}" - ) + raise ValueError(f"Quantization of a float param should be a float or int: {self}") if self._quantization <= 0: raise ValueError(f"Number of quantization points is <= 0: {self}") - if self._distribution is not None and self._distribution not in { - "uniform", - "normal", - "beta", - }: + if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}: raise ValueError(f"Invalid distribution: {self}") if self._distribution_params and self._distribution is None: raise ValueError(f"Must specify the distribution: {self}") @@ -243,9 +217,7 @@ def _sanity_check_numerical(self) -> None: if any(w < 0 for w in self._weights + [self._range_weight]): raise ValueError(f"All weights must be non-negative: {self}") elif self._range_weight is not None: - raise ValueError( - f"Must specify both weights and range_weight or none: {self}" - ) + raise ValueError(f"Must specify both weights and range_weight or none: {self}") def __repr__(self) -> str: """ @@ -279,14 +251,12 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Tunable): return False return bool( - self._name == other._name - and self._type == other._type - and self._current_value == other._current_value + self._name == other._name and + self._type == other._type and + self._current_value == other._current_value ) - def __lt__( - self, other: object - ) -> bool: # pylint: disable=too-many-return-statements + def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements """ Compare the two Tunable objects. We mostly need this to create a canonical list of tunable objects when hashing a TunableGroup. 
@@ -366,33 +336,18 @@ def value(self, value: TunableValue) -> TunableValue: assert value is not None coerced_value = self.dtype(value) except Exception: - _LOG.error( - "Impossible conversion: %s %s <- %s %s", - self._type, - self._name, - type(value), - value, - ) + _LOG.error("Impossible conversion: %s %s <- %s %s", + self._type, self._name, type(value), value) raise if self._type == "int" and isinstance(value, float) and value != coerced_value: - _LOG.error( - "Loss of precision: %s %s <- %s %s", - self._type, - self._name, - type(value), - value, - ) + _LOG.error("Loss of precision: %s %s <- %s %s", + self._type, self._name, type(value), value) raise ValueError(f"Loss of precision: {self._name}={value}") if not self.is_valid(coerced_value): - _LOG.error( - "Invalid assignment: %s %s <- %s %s", - self._type, - self._name, - type(value), - value, - ) + _LOG.error("Invalid assignment: %s %s <- %s %s", + self._type, self._name, type(value), value) raise ValueError(f"Invalid value for the tunable: {self._name}={value}") self._current_value = coerced_value @@ -437,9 +392,7 @@ def is_valid(self, value: TunableValue) -> bool: if isinstance(value, (int, float)): return self.in_range(value) or value in self._special else: - raise ValueError( - f"Invalid value type for tunable {self}: {value}={type(value)}" - ) + raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}") else: raise ValueError(f"Invalid parameter type: {self._type}") @@ -450,10 +403,10 @@ def in_range(self, value: Union[int, float, str, None]) -> bool: Return False if the tunable or value is categorical or None. """ return ( - isinstance(value, (float, int)) - and self.is_numerical - and self._range is not None - and bool(self._range[0] <= value <= self._range[1]) + isinstance(value, (float, int)) and + self.is_numerical and + self._range is not None and + bool(self._range[0] <= value <= self._range[1]) ) @property @@ -673,19 +626,12 @@ def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]: # Be sure to return python types instead of numpy types. cardinality = self.cardinality assert isinstance(cardinality, int) - return ( - float(x) - for x in np.linspace( - start=num_range[0], - stop=num_range[1], - num=cardinality, - endpoint=True, - ) - ) + return (float(x) for x in np.linspace(start=num_range[0], + stop=num_range[1], + num=cardinality, + endpoint=True)) assert self.type == "int", f"Unhandled tunable type: {self}" - return range( - int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1) - ) + return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1)) @property def cardinality(self) -> Union[int, float]: @@ -760,9 +706,7 @@ def categories(self) -> List[Optional[str]]: return self._values @property - def values( - self, - ) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]: + def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]: """ Gets the categories or quantized values for this tunable. 
diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index b48da6fccb..0bd58c8269 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -30,11 +30,9 @@ def __init__(self, config: Optional[dict] = None): if config is None: config = {} ConfigSchema.TUNABLE_PARAMS.validate(config) - self._index: Dict[str, CovariantTunableGroup] = ( - {} - ) # Index (Tunable id -> CovariantTunableGroup) + self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup) self._tunable_groups: Dict[str, CovariantTunableGroup] = {} - for name, group_config in config.items(): + for (name, group_config) in config.items(): self._add_group(CovariantTunableGroup(name, group_config)) def __bool__(self) -> bool: @@ -83,15 +81,11 @@ def _add_group(self, group: CovariantTunableGroup) -> None: ---------- group : CovariantTunableGroup """ - assert ( - group.name not in self._tunable_groups - ), f"Duplicate covariant tunable group name {group.name} in {self}" + assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}" self._tunable_groups[group.name] = group for tunable in group.get_tunables(): if tunable.name in self._index: - raise ValueError( - f"Duplicate Tunable {tunable.name} from group {group.name} in {self}" - ) + raise ValueError(f"Duplicate Tunable {tunable.name} from group {group.name} in {self}") self._index[tunable.name] = group def merge(self, tunables: "TunableGroups") -> "TunableGroups": @@ -125,10 +119,8 @@ def merge(self, tunables: "TunableGroups") -> "TunableGroups": # Check that there's no overlap in the tunables. # But allow for differing current values. if not self._tunable_groups[group.name].equals_defaults(group): - raise ValueError( - f"Overlapping covariant tunable group name {group.name} " - + "in {self._tunable_groups[group.name]} and {tunables}" - ) + raise ValueError(f"Overlapping covariant tunable group name {group.name} " + + "in {self._tunable_groups[group.name]} and {tunables}") return self def __repr__(self) -> str: @@ -140,17 +132,10 @@ def __repr__(self) -> str: string : str A human-readable version of the TunableGroups. """ - return ( - "{ " - + ", ".join( - f"{group.name}::{tunable}" - for group in sorted( - self._tunable_groups.values(), key=lambda g: (-g.cost, g.name) - ) - for tunable in sorted(group._tunables.values()) - ) - + " }" - ) + return "{ " + ", ".join( + f"{group.name}::{tunable}" + for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) + for tunable in sorted(group._tunables.values())) + " }" def __contains__(self, tunable: Union[str, Tunable]) -> bool: """ @@ -166,17 +151,13 @@ def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: name: str = tunable.name if isinstance(tunable, Tunable) else tunable return self._index[name][name] - def __setitem__( - self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] - ) -> TunableValue: + def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: """ Update the current value of a single tunable parameter. 
""" # Use double index to make sure we set the is_updated flag of the group name: str = tunable.name if isinstance(tunable, Tunable) else tunable - value: TunableValue = ( - tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value - ) + value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value self._index[name][name] = value return self._index[name][name] @@ -190,13 +171,9 @@ def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, Non An iterator over all tunables in all groups. Each element is a 2-tuple of an instance of the Tunable parameter and covariant group it belongs to. """ - return ( - (group.get_tunable(name), group) for (name, group) in self._index.items() - ) + return ((group.get_tunable(name), group) for (name, group) in self._index.items()) - def get_tunable( - self, tunable: Union[str, Tunable] - ) -> Tuple[Tunable, CovariantTunableGroup]: + def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]: """ Access the entire Tunable (not just its value) and its covariant group. Throw KeyError if the tunable is not found. @@ -251,17 +228,12 @@ def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": tunables = TunableGroups() for name in group_names: if name not in self._tunable_groups: - raise KeyError( - f"Unknown covariant group name '{name}' in tunable group {self}" - ) + raise KeyError(f"Unknown covariant group name '{name}' in tunable group {self}") tunables._add_group(self._tunable_groups[name]) return tunables - def get_param_values( - self, - group_names: Optional[Iterable[str]] = None, - into_params: Optional[Dict[str, TunableValue]] = None, - ) -> Dict[str, TunableValue]: + def get_param_values(self, group_names: Optional[Iterable[str]] = None, + into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: """ Get the current values of the tunables that belong to the specified covariance groups. @@ -300,10 +272,8 @@ def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: is_updated : bool True if any of the specified tunable groups has been updated, False otherwise. """ - return any( - self._tunable_groups[name].is_updated() - for name in (group_names or self.get_covariant_group_names()) - ) + return any(self._tunable_groups[name].is_updated() + for name in (group_names or self.get_covariant_group_names())) def is_defaults(self) -> bool: """ @@ -315,9 +285,7 @@ def is_defaults(self) -> bool: """ return all(group.is_defaults() for group in self._tunable_groups.values()) - def restore_defaults( - self, group_names: Optional[Iterable[str]] = None - ) -> "TunableGroups": + def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": """ Restore all tunable parameters to their default values. @@ -331,7 +299,7 @@ def restore_defaults( self : TunableGroups Self-reference for chaining. """ - for name in group_names or self.get_covariant_group_names(): + for name in (group_names or self.get_covariant_group_names()): self._tunable_groups[name].restore_defaults() return self @@ -349,7 +317,7 @@ def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": self : TunableGroups Self-reference for chaining. 
""" - for name in group_names or self.get_covariant_group_names(): + for name in (group_names or self.get_covariant_group_names()): self._tunable_groups[name].reset_is_updated() return self diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 2892543e5f..531988be97 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -42,9 +42,7 @@ from mlos_bench.storage.base_storage import Storage # BaseTypeVar is a generic with a constraint of the three base classes. -BaseTypeVar = TypeVar( - "BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage" -) +BaseTypeVar = TypeVar("BaseTypeVar", "Environment", "Optimizer", "Scheduler", "Service", "Storage") BaseTypes = Union["Environment", "Optimizer", "Scheduler", "Service", "Storage"] @@ -73,12 +71,8 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> return dest -def merge_parameters( - *, - dest: dict, - source: Optional[dict] = None, - required_keys: Optional[Iterable[str]] = None, -) -> dict: +def merge_parameters(*, dest: dict, source: Optional[dict] = None, + required_keys: Optional[Iterable[str]] = None) -> dict: """ Merge the source config dict into the destination config. Pick from the source configs *ONLY* the keys that are already present @@ -138,9 +132,8 @@ def path_join(*args: str, abs_path: bool = False) -> str: return os.path.normpath(path).replace("\\", "/") -def prepare_class_load( - config: dict, global_config: Optional[Dict[str, Any]] = None -) -> Tuple[str, Dict[str, Any]]: +def prepare_class_load(config: dict, + global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. @@ -162,11 +155,8 @@ def prepare_class_load( merge_parameters(dest=class_config, source=global_config) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Instantiating: %s with config:\n%s", - class_name, - json.dumps(class_config, indent=2), - ) + _LOG.debug("Instantiating: %s with config:\n%s", + class_name, json.dumps(class_config, indent=2)) return (class_name, class_config) @@ -197,9 +187,8 @@ def get_class_from_name(class_name: str) -> type: # FIXME: Technically, this should return a type "class_name" derived from "base_class". -def instantiate_from_config( - base_class: Type[BaseTypeVar], class_name: str, *args: Any, **kwargs: Any -) -> BaseTypeVar: +def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str, + *args: Any, **kwargs: Any) -> BaseTypeVar: """ Factory method for a new class instantiated from config. @@ -231,9 +220,7 @@ def instantiate_from_config( return ret -def check_required_params( - config: Mapping[str, Any], required_params: Iterable[str] -) -> None: +def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None: """ Check if all required parameters are present in the configuration. Raise ValueError if any of the parameters are missing. 
@@ -251,8 +238,7 @@ def check_required_params( if missing_params: raise ValueError( "The following parameters must be provided in the configuration" - + f" or as command line arguments: {missing_params}" - ) + + f" or as command line arguments: {missing_params}") def get_git_info(path: str = __file__) -> Tuple[str, str, str]: @@ -271,14 +257,11 @@ def get_git_info(path: str = __file__) -> Tuple[str, str, str]: """ dirname = os.path.dirname(path) git_repo = subprocess.check_output( - ["git", "-C", dirname, "remote", "get-url", "origin"], text=True - ).strip() + ["git", "-C", dirname, "remote", "get-url", "origin"], text=True).strip() git_commit = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "HEAD"], text=True - ).strip() + ["git", "-C", dirname, "rev-parse", "HEAD"], text=True).strip() git_root = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True - ).strip() + ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True).strip() _LOG.debug("Current git branch: %s %s", git_repo, git_commit) rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root)) return (git_repo, git_commit, rel_path.replace("\\", "/")) @@ -334,9 +317,7 @@ def nullable(func: Callable, value: Optional[Any]) -> Optional[Any]: return None if value is None else func(value) -def utcify_timestamp( - timestamp: datetime, *, origin: Literal["utc", "local"] -) -> datetime: +def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> datetime: """ Augment a timestamp with zoneinfo if missing and convert it to UTC. @@ -374,9 +355,7 @@ def utcify_timestamp( raise ValueError(f"Invalid origin: {origin}") -def utcify_nullable_timestamp( - timestamp: Optional[datetime], *, origin: Literal["utc", "local"] -) -> Optional[datetime]: +def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]: """ A nullable version of utcify_timestamp. """ @@ -388,9 +367,7 @@ def utcify_nullable_timestamp( _MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) -def datetime_parser( - datetime_col: pandas.Series, *, origin: Literal["utc", "local"] -) -> pandas.Series: +def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series: """ Attempt to convert a pandas column to a datetime format. @@ -424,7 +401,7 @@ def datetime_parser( new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo) assert new_datetime_col.dt.tz is not None # And convert it to UTC. - new_datetime_col = new_datetime_col.dt.tz_convert("UTC") + new_datetime_col = new_datetime_col.dt.tz_convert('UTC') if new_datetime_col.isna().any(): raise ValueError(f"Invalid date format in the data: {datetime_col}") if new_datetime_col.le(_MIN_TS).any(): diff --git a/mlos_bench/mlos_bench/version.py b/mlos_bench/mlos_bench/version.py index f8acae8c02..96d3d2b6bf 100644 --- a/mlos_bench/mlos_bench/version.py +++ b/mlos_bench/mlos_bench/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. 
-VERSION = "0.5.1" +VERSION = '0.5.1' if __name__ == "__main__": print(VERSION) diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index b2090424a6..27d844c35b 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -21,24 +21,21 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns["VERSION"] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns['VERSION'] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - - version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) + version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: warning("setuptools_scm not found, using version from version.py") except LookupError as e: - warning( - f"setuptools_scm failed to find git version, using version from version.py: {e}" - ) + warning(f"setuptools_scm failed to find git version, using version from version.py: {e}") # A simple routine to read and adjust the README.md for this module into a format @@ -50,72 +47,62 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, "README.md") + readme_path = os.path.join(pkg_dir, 'README.md') if not os.path.isfile(readme_path): return { - "long_description": "missing", + 'long_description': 'missing', } - jsonc_re = re.compile(r"```jsonc") - link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") - with open(readme_path, mode="r", encoding="utf-8") as readme_fh: + jsonc_re = re.compile(r'```jsonc') + link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') + with open(readme_path, mode='r', encoding='utf-8') as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r"```json", line) for line in lines] + lines = [jsonc_re.sub(r'```json', line) for line in lines] return { - "long_description": "".join(lines), - "long_description_content_type": "text/markdown", + 'long_description': ''.join(lines), + 'long_description_content_type': 'text/markdown', } -extra_requires: Dict[str, List[str]] = ( - { # pylint: disable=consider-using-namedtuple-or-dataclass - # Additional tools for extra functionality. - "azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"], - "ssh": ["asyncssh"], - "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"], - "storage-sql-mysql": ["sqlalchemy", "mysql-connector-python"], - "storage-sql-postgres": ["sqlalchemy", "psycopg2"], - "storage-sql-sqlite": [ - "sqlalchemy" - ], # sqlite3 comes with python, so we don't need to install it. - # Transitive extra_requires from mlos-core. - "flaml": ["flaml[blendsearch]"], - "smac": ["smac"], - } -) +extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass + # Additional tools for extra functionality. 
+ 'azure': ['azure-storage-file-share', 'azure-identity', 'azure-keyvault'], + 'ssh': ['asyncssh'], + 'storage-sql-duckdb': ['sqlalchemy', 'duckdb_engine'], + 'storage-sql-mysql': ['sqlalchemy', 'mysql-connector-python'], + 'storage-sql-postgres': ['sqlalchemy', 'psycopg2'], + 'storage-sql-sqlite': ['sqlalchemy'], # sqlite3 comes with python, so we don't need to install it. + # Transitive extra_requires from mlos-core. + 'flaml': ['flaml[blendsearch]'], + 'smac': ['smac'], +} # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires["full"] = list(set(chain(*extra_requires.values()))) +extra_requires['full'] = list(set(chain(*extra_requires.values()))) -extra_requires["full-tests"] = extra_requires["full"] + [ - "pytest", - "pytest-forked", - "pytest-xdist", - "pytest-cov", - "pytest-local-badge", - "pytest-lazy-fixtures", - "pytest-docker", - "fasteners", +extra_requires['full-tests'] = extra_requires['full'] + [ + 'pytest', + 'pytest-forked', + 'pytest-xdist', + 'pytest-cov', + 'pytest-local-badge', + 'pytest-lazy-fixtures', + 'pytest-docker', + 'fasteners', ] setup( version=VERSION, install_requires=[ - "mlos-core==" + VERSION, - "requests", - "json5", - "jsonschema>=4.18.0", - "referencing>=0.29.1", + 'mlos-core==' + VERSION, + 'requests', + 'json5', + 'jsonschema>=4.18.0', 'referencing>=0.29.1', 'importlib_resources;python_version<"3.10"', - ] - + extra_requires[ - "storage-sql-sqlite" - ], # NOTE: For now sqlite is a fallback storage backend, so we always install it. + ] + extra_requires['storage-sql-sqlite'], # NOTE: For now sqlite is a fallback storage backend, so we always install it. extras_require=extra_requires, - **_get_long_desc_from_readme( - "https://github.com/microsoft/MLOS/tree/main/mlos_bench" - ), + **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_bench'), ) diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index b3e248e407..086002af62 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -18,12 +18,12 @@ from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType __all__ = [ - "SpaceAdapterType", - "OptimizerFactory", - "BaseOptimizer", - "RandomOptimizer", - "FlamlOptimizer", - "SmacOptimizer", + 'SpaceAdapterType', + 'OptimizerFactory', + 'BaseOptimizer', + 'RandomOptimizer', + 'FlamlOptimizer', + 'SmacOptimizer', ] @@ -45,7 +45,7 @@ class OptimizerType(Enum): # ConcreteOptimizer = TypeVar('ConcreteOptimizer', *[member.value for member in OptimizerType]) # To address this, we add a test for complete coverage of the enum. 
ConcreteOptimizer = TypeVar( - "ConcreteOptimizer", + 'ConcreteOptimizer', RandomOptimizer, FlamlOptimizer, SmacOptimizer, @@ -60,15 +60,13 @@ class OptimizerFactory: # pylint: disable=too-few-public-methods @staticmethod - def create( - *, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, - optimizer_kwargs: Optional[dict] = None, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None, - ) -> ConcreteOptimizer: # type: ignore[type-var] + def create(*, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, + optimizer_kwargs: Optional[dict] = None, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None) -> ConcreteOptimizer: # type: ignore[type-var] """ Create a new optimizer instance, given the parameter space, optimizer type, and potential optimizer options. @@ -109,7 +107,7 @@ def create( parameter_space=parameter_space, optimization_targets=optimization_targets, space_adapter=space_adapter, - **optimizer_kwargs, + **optimizer_kwargs ) return optimizer diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py index d4f59dfa52..5f32219988 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py @@ -12,6 +12,6 @@ from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer __all__ = [ - "BaseBayesianOptimizer", - "SmacOptimizer", + 'BaseBayesianOptimizer', + 'SmacOptimizer', ] diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 9d3bcabcb2..76ff0d9b3a 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -19,9 +19,8 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): """Abstract base class defining the interface for Bayesian optimization.""" @abstractmethod - def surrogate_predict( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def surrogate_predict(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None) -> npt.NDArray: """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). Parameters @@ -32,12 +31,11 @@ def surrogate_predict( context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def acquisition_function( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def acquisition_function(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None) -> npt.NDArray: """Invokes the acquisition function from this Bayesian optimizer for the given configuration. Parameters @@ -48,4 +46,4 @@ def acquisition_function( context : pd.DataFrame Not Yet Implemented. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 4364f4c172..9d8d2a0347 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -29,22 +29,19 @@ class SmacOptimizer(BaseBayesianOptimizer): Wrapper class for SMAC based Bayesian optimization. """ - def __init__( - self, - *, # pylint: disable=too-many-locals,too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - seed: Optional[int] = 0, - run_name: Optional[str] = None, - output_directory: Optional[str] = None, - max_trials: int = 100, - n_random_init: Optional[int] = None, - max_ratio: Optional[float] = None, - use_default_config: bool = False, - n_random_probability: float = 0.1, - ): + def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + seed: Optional[int] = 0, + run_name: Optional[str] = None, + output_directory: Optional[str] = None, + max_trials: int = 100, + n_random_init: Optional[int] = None, + max_ratio: Optional[float] = None, + use_default_config: bool = False, + n_random_probability: float = 0.1): """ Instantiate a new SMAC optimizer wrapper. @@ -127,9 +124,7 @@ def __init__( if output_directory is None: # pylint: disable=consider-using-with try: - self._temp_output_directory = TemporaryDirectory( - ignore_cleanup_errors=True - ) # Argument added in Python 3.10 + self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True) # Argument added in Python 3.10 except TypeError: self._temp_output_directory = TemporaryDirectory() output_directory = self._temp_output_directory.name @@ -151,12 +146,8 @@ def __init__( seed=seed or -1, # if -1, SMAC will generate a random seed internally n_workers=1, # Use a single thread for evaluating trials ) - intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier( - scenario, max_config_calls=1 - ) - config_selector: ConfigSelector = Optimizer_Smac.get_config_selector( - scenario, retrain_after=1 - ) + intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier(scenario, max_config_calls=1) + config_selector: ConfigSelector = Optimizer_Smac.get_config_selector(scenario, retrain_after=1) # TODO: When bulk registering prior configs to rewarm the optimizer, # there is a way to inform SMAC's initial design that we have @@ -167,27 +158,27 @@ def __init__( # See Also: #488 initial_design_args: Dict[str, Union[list, int, float, Scenario]] = { - "scenario": scenario, + 'scenario': scenario, # Workaround a bug in SMAC that sets a default arg to a mutable # value that can cause issues when multiple optimizers are # instantiated with the use_default_config option within the same # process that use different ConfigSpaces so that the second # receives the default config from both as an additional config. 
- "additional_configs": [], + 'additional_configs': [] } if n_random_init is not None: - initial_design_args["n_configs"] = n_random_init + initial_design_args['n_configs'] = n_random_init if n_random_init > 0.25 * max_trials and max_ratio is None: warning( - "Number of random initial configs (%d) is " - + "greater than 25%% of max_trials (%d). " - + "Consider setting max_ratio to avoid SMAC overriding n_random_init.", + 'Number of random initial configs (%d) is ' + + 'greater than 25%% of max_trials (%d). ' + + 'Consider setting max_ratio to avoid SMAC overriding n_random_init.', n_random_init, max_trials, ) if max_ratio is not None: assert isinstance(max_ratio, float) and 0.0 <= max_ratio <= 1.0 - initial_design_args["max_ratio"] = max_ratio + initial_design_args['max_ratio'] = max_ratio # Use the default InitialDesign from SMAC. # (currently SBOL instead of LatinHypercube due to better uniformity @@ -199,9 +190,7 @@ def __init__( # design when generated a random_design for itself via the # get_random_design static method when random_design is None. assert isinstance(n_random_probability, float) and n_random_probability >= 0 - random_design = ProbabilityRandomDesign( - probability=n_random_probability, seed=scenario.seed - ) + random_design = ProbabilityRandomDesign(probability=n_random_probability, seed=scenario.seed) self.base_optimizer = Optimizer_Smac( scenario, @@ -211,8 +200,7 @@ def __init__( random_design=random_design, config_selector=config_selector, multi_objective_algorithm=Optimizer_Smac.get_multi_objective_algorithm( - scenario, objective_weights=self._objective_weights - ), + scenario, objective_weights=self._objective_weights), overwrite=True, logging_level=False, # Use the existing logger ) @@ -253,16 +241,10 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None """ # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946 - raise RuntimeError("This function should never be called.") - - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + raise RuntimeError('This function should never be called.') + + def _register(self, *, configs: pd.DataFrame, + scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. 
Parameters @@ -286,30 +268,20 @@ def _register( ) if context is not None: - warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) # Register each trial (one-by-one) - for config, (_i, score) in zip( - self._to_configspace_configs(configs=configs), scores.iterrows() - ): + for (config, (_i, score)) in zip(self._to_configspace_configs(configs=configs), scores.iterrows()): # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance info: TrialInfo = self.trial_info_map.get( - config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) - ) - value = TrialValue( - cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS - ) + config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed)) + value = TrialValue(cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS) self.base_optimizer.tell(info, value, save=False) # Save optimizer once we register all configs self.base_optimizer.optimizer.save() - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Parameters @@ -331,99 +303,62 @@ def _suggest( ) if context is not None: - warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) trial: TrialInfo = self.base_optimizer.ask() trial.config.is_valid_configuration() self.optimizer_parameter_space.check_configuration(trial.config) assert trial.config.config_space == self.optimizer_parameter_space self.trial_info_map[trial.config] = trial - config_df = pd.DataFrame( - [trial.config], columns=list(self.optimizer_parameter_space.keys()) - ) + config_df = pd.DataFrame([trial.config], columns=list(self.optimizer_parameter_space.keys())) return config_df, None - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None) -> None: raise NotImplementedError() - def surrogate_predict( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: from smac.utils.configspace import ( convert_configurations_to_array, # pylint: disable=import-outside-toplevel ) if context is not None: - warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) if self._space_adapter and not isinstance(self._space_adapter, IdentityAdapter): - raise NotImplementedError( - "Space adapter not supported for surrogate_predict." 
- ) + raise NotImplementedError("Space adapter not supported for surrogate_predict.") # pylint: disable=protected-access if len(self._observations) <= self.base_optimizer._initial_design._n_configs: raise RuntimeError( - "Surrogate model can make predictions *only* after all initial points have been evaluated " - + f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}" - ) + 'Surrogate model can make predictions *only* after all initial points have been evaluated ' + + f'{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}') if self.base_optimizer._config_selector._model is None: - raise RuntimeError("Surrogate model is not yet trained") + raise RuntimeError('Surrogate model is not yet trained') - config_array: npt.NDArray = convert_configurations_to_array( - self._to_configspace_configs(configs=configs) - ) - mean_predictions, _ = self.base_optimizer._config_selector._model.predict( - config_array - ) - return mean_predictions.reshape( - -1, - ) + config_array: npt.NDArray = convert_configurations_to_array(self._to_configspace_configs(configs=configs)) + mean_predictions, _ = self.base_optimizer._config_selector._model.predict(config_array) + return mean_predictions.reshape(-1,) - def acquisition_function( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: if context is not None: - warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) if self._space_adapter: raise NotImplementedError() # pylint: disable=protected-access if self.base_optimizer._config_selector._acquisition_function is None: - raise RuntimeError("Acquisition function is not yet initialized") + raise RuntimeError('Acquisition function is not yet initialized') cs_configs: list = self._to_configspace_configs(configs=configs) - return self.base_optimizer._config_selector._acquisition_function( - cs_configs - ).reshape( - -1, - ) + return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape(-1,) def cleanup(self) -> None: - if ( - hasattr(self, "_temp_output_directory") - and self._temp_output_directory is not None - ): + if hasattr(self, '_temp_output_directory') and self._temp_output_directory is not None: self._temp_output_directory.cleanup() self._temp_output_directory = None - def _to_configspace_configs( - self, *, configs: pd.DataFrame - ) -> List[ConfigSpace.Configuration]: + def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace.Configuration]: """Convert a dataframe of configs to a list of ConfigSpace configs. Parameters @@ -437,8 +372,6 @@ def _to_configspace_configs( List of ConfigSpace configs. 
""" return [ - ConfigSpace.Configuration( - self.optimizer_parameter_space, values=config.to_dict() - ) - for (_, config) in configs.astype("O").iterrows() + ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict()) + for (_, config) in configs.astype('O').iterrows() ] diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 638613c43d..273c89eecc 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -33,16 +33,13 @@ class FlamlOptimizer(BaseOptimizer): # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. _METRIC_NAME = "FLAML_score" - def __init__( - self, - *, # pylint: disable=too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - low_cost_partial_config: Optional[dict] = None, - seed: Optional[int] = None, - ): + def __init__(self, *, # pylint: disable=too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + low_cost_partial_config: Optional[dict] = None, + seed: Optional[int] = None): """ Create an MLOS wrapper for FLAML. @@ -85,22 +82,14 @@ def __init__( configspace_to_flaml_space, ) - self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space( - self.optimizer_parameter_space - ) + self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(self.optimizer_parameter_space) self.low_cost_partial_config = low_cost_partial_config self.evaluated_samples: Dict[ConfigSpace.Configuration, EvaluatedSample] = {} self._suggested_config: Optional[dict] - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. Parameters @@ -118,34 +107,21 @@ def _register( Not Yet Implemented. 
""" if context is not None: - warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) if metadata is not None: - warn( - f"Not Implemented: Ignoring metadata {list(metadata.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring metadata {list(metadata.columns)}", UserWarning) - for (_, config), (_, score) in zip( - configs.astype("O").iterrows(), scores.iterrows() - ): + for (_, config), (_, score) in zip(configs.astype('O').iterrows(), scores.iterrows()): cs_config: ConfigSpace.Configuration = ConfigSpace.Configuration( - self.optimizer_parameter_space, values=config.to_dict() - ) + self.optimizer_parameter_space, values=config.to_dict()) if cs_config in self.evaluated_samples: warn(f"Configuration {config} was already registered", UserWarning) self.evaluated_samples[cs_config] = EvaluatedSample( config=config.to_dict(), - score=float( - np.average(score.astype(float), weights=self._objective_weights) - ), + score=float(np.average(score.astype(float), weights=self._objective_weights)), ) - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Sampled at random using ConfigSpace. @@ -164,20 +140,12 @@ def _suggest( Not implemented. """ if context is not None: - warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) config: dict = self._get_next_config() return pd.DataFrame(config, index=[0]), None - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: raise NotImplementedError() def _target_function(self, config: dict) -> Union[dict, None]: @@ -232,14 +200,16 @@ def _get_next_config(self) -> dict: dict(normalize_config(self.optimizer_parameter_space, conf)) for conf in self.evaluated_samples ] - evaluated_rewards = [s.score for s in self.evaluated_samples.values()] + evaluated_rewards = [ + s.score for s in self.evaluated_samples.values() + ] # Warm start FLAML optimizer self._suggested_config = None tune.run( self._target_function, config=self.flaml_parameter_space, - mode="min", + mode='min', metric=self._METRIC_NAME, points_to_evaluate=points_to_evaluate, evaluated_rewards=evaluated_rewards, @@ -248,6 +218,6 @@ def _get_next_config(self) -> dict: verbose=0, ) if self._suggested_config is None: - raise RuntimeError("FLAML did not produce a suggestion") + raise RuntimeError('FLAML did not produce a suggestion') return self._suggested_config # type: ignore[unreachable] diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 8e80de16f1..4ab9db5a2f 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -24,14 +24,11 @@ class BaseOptimizer(metaclass=ABCMeta): Optimizer abstract base class defining the basic interface. 
""" - def __init__( - self, - *, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - ): + def __init__(self, *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None): """ Create a new instance of the base optimizer. @@ -47,37 +44,21 @@ def __init__( The space adapter class to employ for parameter space transformations. """ self.parameter_space: ConfigSpace.ConfigurationSpace = parameter_space - self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = ( - parameter_space - if space_adapter is None - else space_adapter.target_parameter_space - ) - - if ( - space_adapter is not None - and space_adapter.orig_parameter_space != parameter_space - ): - raise ValueError( - "Given parameter space differs from the one given to space adapter" - ) + self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = \ + parameter_space if space_adapter is None else space_adapter.target_parameter_space + + if space_adapter is not None and space_adapter.orig_parameter_space != parameter_space: + raise ValueError("Given parameter space differs from the one given to space adapter") self._optimization_targets = optimization_targets self._objective_weights = objective_weights - if objective_weights is not None and len(objective_weights) != len( - optimization_targets - ): - raise ValueError( - "Number of weights must match the number of optimization targets" - ) + if objective_weights is not None and len(objective_weights) != len(optimization_targets): + raise ValueError("Number of weights must match the number of optimization targets") self._space_adapter: Optional[BaseSpaceAdapter] = space_adapter - self._observations: List[ - Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]] - ] = [] + self._observations: List[Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]] = [] self._has_context: Optional[bool] = None - self._pending_observations: List[ - Tuple[pd.DataFrame, Optional[pd.DataFrame]] - ] = [] + self._pending_observations: List[Tuple[pd.DataFrame, Optional[pd.DataFrame]]] = [] def __repr__(self) -> str: return f"{self.__class__.__name__}(space_adapter={self.space_adapter})" @@ -87,14 +68,8 @@ def space_adapter(self) -> Optional[BaseSpaceAdapter]: """Get the space adapter instance (if any).""" return self._space_adapter - def register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Wrapper method, which employs the space adapter (if any), before registering the configs and scores. Parameters @@ -112,39 +87,29 @@ def register( """ # Do some input validation. assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(scores.columns) == set( - self._optimization_targets - ), "Mismatched optimization targets." - assert self._has_context is None or self._has_context ^ ( - context is None - ), "Context must always be added or never be added." - assert len(configs) == len(scores), "Mismatched number of configs and scores." + assert set(scores.columns) == set(self._optimization_targets), \ + "Mismatched optimization targets." 
+ assert self._has_context is None or self._has_context ^ (context is None), \ + "Context must always be added or never be added." + assert len(configs) == len(scores), \ + "Mismatched number of configs and scores." if context is not None: - assert len(configs) == len( - context - ), "Mismatched number of configs and context." - assert configs.shape[1] == len( - self.parameter_space.values() - ), "Mismatched configuration shape." + assert len(configs) == len(context), \ + "Mismatched number of configs and context." + assert configs.shape[1] == len(self.parameter_space.values()), \ + "Mismatched configuration shape." self._observations.append((configs, scores, context)) self._has_context = context is not None if self._space_adapter: configs = self._space_adapter.inverse_transform(configs) - assert configs.shape[1] == len( - self.optimizer_parameter_space.values() - ), "Mismatched configuration shape after inverse transform." + assert configs.shape[1] == len(self.optimizer_parameter_space.values()), \ + "Mismatched configuration shape after inverse transform." return self._register(configs=configs, scores=scores, context=context) @abstractmethod - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. Parameters @@ -157,11 +122,10 @@ def _register( context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover - def suggest( - self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def suggest(self, *, context: Optional[pd.DataFrame] = None, + defaults: bool = False) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. @@ -179,31 +143,24 @@ def suggest( Pandas dataframe with a single row. Column names are the parameter names. """ if defaults: - configuration = config_to_dataframe( - self.parameter_space.get_default_configuration() - ) + configuration = config_to_dataframe(self.parameter_space.get_default_configuration()) metadata = None if self.space_adapter is not None: configuration = self.space_adapter.inverse_transform(configuration) else: configuration, metadata = self._suggest(context=context) - assert ( - len(configuration) == 1 - ), "Suggest must return a single configuration." - assert set(configuration.columns).issubset( - set(self.optimizer_parameter_space) - ), "Optimizer suggested a configuration that does not match the expected parameter space." + assert len(configuration) == 1, \ + "Suggest must return a single configuration." + assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), \ + "Optimizer suggested a configuration that does not match the expected parameter space." if self._space_adapter: configuration = self._space_adapter.transform(configuration) - assert set(configuration.columns).issubset( - set(self.parameter_space) - ), "Space adapter produced a configuration that does not match the expected parameter space." 
+ assert set(configuration.columns).issubset(set(self.parameter_space)), \ + "Space adapter produced a configuration that does not match the expected parameter space." return configuration, metadata @abstractmethod - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Parameters @@ -219,16 +176,12 @@ def _suggest( metadata : Optional[pd.DataFrame] The metadata associated with the given configuration used for evaluations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs as "pending". That is it say, it has been suggested by the optimizer, and an experiment trial has been started. This can be useful for executing multiple trials in parallel, retry logic, etc. @@ -242,11 +195,9 @@ def register_pending( metadata : Optional[pd.DataFrame] Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover - def get_observations( - self, - ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: + def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Returns the observations as a triplet of DataFrames (config, score, context). @@ -257,23 +208,13 @@ def get_observations( """ if len(self._observations) == 0: raise ValueError("No observations registered yet.") - configs = pd.concat( - [config for config, _, _ in self._observations] - ).reset_index(drop=True) - scores = pd.concat([score for _, score, _ in self._observations]).reset_index( - drop=True - ) - contexts = pd.concat( - [ - pd.DataFrame() if context is None else context - for _, _, context in self._observations - ] - ).reset_index(drop=True) + configs = pd.concat([config for config, _, _ in self._observations]).reset_index(drop=True) + scores = pd.concat([score for _, score, _ in self._observations]).reset_index(drop=True) + contexts = pd.concat([pd.DataFrame() if context is None else context + for _, _, context in self._observations]).reset_index(drop=True) return (configs, scores, contexts if len(contexts.columns) > 0 else None) - def get_best_observations( - self, *, n_max: int = 1 - ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: + def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Get the N best observations so far as a triplet of DataFrames (config, score, context). Default is N=1. The columns are ordered in ASCENDING order of the optimization targets. 
@@ -292,14 +233,9 @@ def get_best_observations( if len(self._observations) == 0: raise ValueError("No observations registered yet.") (configs, scores, contexts) = self.get_observations() - idx = scores.nsmallest( - n_max, columns=self._optimization_targets, keep="first" - ).index - return ( - configs.loc[idx], - scores.loc[idx], - None if contexts is None else contexts.loc[idx], - ) + idx = scores.nsmallest(n_max, columns=self._optimization_targets, keep="first").index + return (configs.loc[idx], scores.loc[idx], + None if contexts is None else contexts.loc[idx]) def cleanup(self) -> None: """ @@ -317,7 +253,7 @@ def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame: j = 0 for param in self.optimizer_parameter_space.values(): if isinstance(param, ConfigSpace.CategoricalHyperparameter): - for offset, val in enumerate(param.choices): + for (offset, val) in enumerate(param.choices): if config[i][j + offset] == 1: df_dict[param.name].append(val) break diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index f1ce489b28..0af785ef20 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -24,14 +24,8 @@ class RandomOptimizer(BaseOptimizer): The parameter space to optimize. """ - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. Doesn't do anything on the RandomOptimizer except storing configs for logging. @@ -51,20 +45,12 @@ def _register( Not Yet Implemented. """ if context is not None: - warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) if metadata is not None: - warn( - f"Not Implemented: Ignoring context {list(metadata.columns)}", - UserWarning, - ) + warn(f"Not Implemented: Ignoring context {list(metadata.columns)}", UserWarning) # should we pop them from self.pending_observations? - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Sampled at random using ConfigSpace. @@ -84,23 +70,10 @@ def _suggest( """ if context is not None: # not sure how that works here? 
- warn( - f"Not Implemented: Ignoring context {list(context.columns)}", - UserWarning, - ) - return ( - pd.DataFrame( - dict(self.optimizer_parameter_space.sample_configuration()), index=[0] - ), - None, - ) - - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) + return pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), None + + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: raise NotImplementedError() # self._pending_observations.append((configs, context)) diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 73e7f37dc3..2e2f585590 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -15,8 +15,8 @@ from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter __all__ = [ - "IdentityAdapter", - "LlamaTuneAdapter", + 'IdentityAdapter', + 'LlamaTuneAdapter', ] @@ -35,7 +35,7 @@ class SpaceAdapterType(Enum): # ConcreteSpaceAdapter = TypeVar('ConcreteSpaceAdapter', *[member.value for member in SpaceAdapterType]) # To address this, we add a test for complete coverage of the enum. ConcreteSpaceAdapter = TypeVar( - "ConcreteSpaceAdapter", + 'ConcreteSpaceAdapter', IdentityAdapter, LlamaTuneAdapter, ) @@ -47,12 +47,10 @@ class SpaceAdapterFactory: # pylint: disable=too-few-public-methods @staticmethod - def create( - *, - parameter_space: ConfigSpace.ConfigurationSpace, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None, - ) -> ConcreteSpaceAdapter: # type: ignore[type-var] + def create(*, + parameter_space: ConfigSpace.ConfigurationSpace, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None) -> ConcreteSpaceAdapter: # type: ignore[type-var] """ Create a new space adapter instance, given the parameter space and potential space adapter options. @@ -77,7 +75,8 @@ def create( space_adapter_kwargs = {} space_adapter: ConcreteSpaceAdapter = space_adapter_type.value( - orig_parameter_space=parameter_space, **space_adapter_kwargs + orig_parameter_space=parameter_space, + **space_adapter_kwargs ) return space_adapter diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index cc7b22b708..6c3a86fc8a 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -22,9 +22,7 @@ class BaseSpaceAdapter(metaclass=ABCMeta): """ def __init__(self, *, orig_parameter_space: ConfigSpace.ConfigurationSpace): - self._orig_parameter_space: ConfigSpace.ConfigurationSpace = ( - orig_parameter_space - ) + self._orig_parameter_space: ConfigSpace.ConfigurationSpace = orig_parameter_space self._random_state = orig_parameter_space.random def __repr__(self) -> str: @@ -48,7 +46,7 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: """ Target parameter space that is fed to the underlying optimizer. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @@ -66,7 +64,7 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: Pandas dataframe with a single row, containing the translated configuration. Column names are the parameter names of the original parameter space. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: @@ -86,4 +84,4 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: Dataframe of the translated configurations / parameters. The columns are the parameter names of the target parameter space and the rows are the configurations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 9c98b772ec..4d3a925cbc 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -19,9 +19,7 @@ from mlos_core.util import normalize_config -class LlamaTuneAdapter( - BaseSpaceAdapter -): # pylint: disable=too-many-instance-attributes +class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes """ Implementation of LlamaTune, a set of parameter space transformation techniques, aimed at improving the sample-efficiency of the underlying optimizer. @@ -30,23 +28,18 @@ class LlamaTuneAdapter( DEFAULT_NUM_LOW_DIMS = 16 """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" - DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2 + DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = .2 """Default percentage of bias for each special parameter value""" DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 """Default number of (max) unique values of each parameter, when space discretization is used""" - def __init__( - self, - *, - orig_parameter_space: ConfigSpace.ConfigurationSpace, - num_low_dims: int = DEFAULT_NUM_LOW_DIMS, - special_param_values: Optional[dict] = None, - max_unique_values_per_param: Optional[ - int - ] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, - use_approximate_reverse_mapping: bool = False, - ): + def __init__(self, *, + orig_parameter_space: ConfigSpace.ConfigurationSpace, + num_low_dims: int = DEFAULT_NUM_LOW_DIMS, + special_param_values: Optional[dict] = None, + max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, + use_approximate_reverse_mapping: bool = False): """ Create a space adapter that employs LlamaTune's techniques. @@ -65,9 +58,7 @@ def __init__( super().__init__(orig_parameter_space=orig_parameter_space) if num_low_dims >= len(orig_parameter_space): - raise ValueError( - "Number of target config space dimensions should be less than those of original config space." 
- ) + raise ValueError("Number of target config space dimensions should be less than those of original config space.") # Validate input special param values dict special_param_values = special_param_values or {} @@ -88,9 +79,7 @@ def __init__( self._sigma_vector = self._random_state.choice([-1, 1], num_orig_dims) # Used to retrieve the low-dim point, given the high-dim one - self._suggested_configs: Dict[ - ConfigSpace.Configuration, ConfigSpace.Configuration - ] = {} + self._suggested_configs: Dict[ConfigSpace.Configuration, ConfigSpace.Configuration] = {} self._pinv_matrix: npt.NDArray self._use_approximate_reverse_mapping = use_approximate_reverse_mapping @@ -101,10 +90,9 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: target_configurations = [] - for _, config in configurations.astype("O").iterrows(): + for (_, config) in configurations.astype('O').iterrows(): configuration = ConfigSpace.Configuration( - self.orig_parameter_space, values=config.to_dict() - ) + self.orig_parameter_space, values=config.to_dict()) target_config = self._suggested_configs.get(configuration, None) # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection @@ -112,22 +100,16 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # respective high-dim point; this way we can retrieve the low-dim point, from its high-dim counterpart. if target_config is None: # Inherently it is not supported to register points, which were not suggested by the optimizer. - if ( - configuration - == self.orig_parameter_space.get_default_configuration() - ): + if configuration == self.orig_parameter_space.get_default_configuration(): # Default configuration should always be registerable. pass elif not self._use_approximate_reverse_mapping: - raise ValueError( - f"{repr(configuration)}\n" - "The above configuration was not suggested by the optimizer. " - "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " - "previously by the optimizer can be registered." - ) + raise ValueError(f"{repr(configuration)}\n" "The above configuration was not suggested by the optimizer. " + "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " + "previously by the optimizer can be registered.") # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix. 
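
The pseudo-inverse trick mentioned in the comment above can be shown in isolation. This is a standalone NumPy sketch of the idea only (illustrative names, deterministic toy data), not the adapter's code: HeSBO ties each original dimension to one low dimension with a random +/-1 sign, so the forward projection is a sparse linear map, and its Moore-Penrose pseudo-inverse gives an approximate way back (exact here because the toy matrix has full column rank).

    import numpy as np

    num_low_dims, num_orig_dims = 2, 6
    h = np.array([0, 1, 0, 1, 0, 1])        # which low dimension feeds each original dimension
    sigma = np.array([1, -1, 1, 1, -1, 1])  # sign assigned to each original dimension

    proj = np.zeros((num_orig_dims, num_low_dims))
    proj[np.arange(num_orig_dims), h] = sigma    # forward map: high = proj @ low

    low = np.array([0.25, -0.8])                 # a point in the low-dimensional [-1, 1] space
    high = proj @ low                            # its high-dimensional projection
    recovered = np.linalg.pinv(proj) @ high      # approximate reverse mapping

    assert np.allclose(recovered, low)
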
- if getattr(self, "_pinv_matrix", None) is None: + if getattr(self, '_pinv_matrix', None) is None: self._try_generate_approx_inverse_mapping() # Replace NaNs with zeros for inactive hyperparameters @@ -136,43 +118,29 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # NOTE: applying special value biasing is not possible vector = self._config_scaler.inverse_transform([config_vector])[0] target_config_vector = self._pinv_matrix.dot(vector) - target_config = ConfigSpace.Configuration( - self.target_parameter_space, vector=target_config_vector - ) + target_config = ConfigSpace.Configuration(self.target_parameter_space, vector=target_config_vector) target_configurations.append(target_config) - return pd.DataFrame( - target_configurations, columns=list(self.target_parameter_space.keys()) - ) + return pd.DataFrame(target_configurations, columns=list(self.target_parameter_space.keys())) def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: if len(configuration) != 1: - raise ValueError( - "Configuration dataframe must contain exactly 1 row. " - f"Found {len(configuration)} rows." - ) + raise ValueError("Configuration dataframe must contain exactly 1 row. " + f"Found {len(configuration)} rows.") target_values_dict = configuration.iloc[0].to_dict() - target_configuration = ConfigSpace.Configuration( - self.target_parameter_space, values=target_values_dict - ) + target_configuration = ConfigSpace.Configuration(self.target_parameter_space, values=target_values_dict) orig_values_dict = self._transform(target_values_dict) - orig_configuration = normalize_config( - self.orig_parameter_space, orig_values_dict - ) + orig_configuration = normalize_config(self.orig_parameter_space, orig_values_dict) # Add to inverse dictionary -- needed for registering the performance later self._suggested_configs[orig_configuration] = target_configuration - return pd.DataFrame( - [list(orig_configuration.values())], columns=list(orig_configuration.keys()) - ) + return pd.DataFrame([list(orig_configuration.values())], columns=list(orig_configuration.keys())) - def _construct_low_dim_space( - self, num_low_dims: int, max_unique_values_per_param: Optional[int] - ) -> None: + def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_param: Optional[int]) -> None: """Constructs the low-dimensional parameter (potentially discretized) search space. Parameters @@ -188,9 +156,7 @@ def _construct_low_dim_space( q_scaler = None if max_unique_values_per_param is None: hyperparameters = [ - ConfigSpace.UniformFloatHyperparameter( - name=f"dim_{idx}", lower=-1, upper=1 - ) + ConfigSpace.UniformFloatHyperparameter(name=f'dim_{idx}', lower=-1, upper=1) for idx in range(num_low_dims) ] else: @@ -198,9 +164,7 @@ def _construct_low_dim_space( # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters. # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection. 
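
To make the scaling mentioned in the comment above concrete, here is a tiny standalone illustration (not part of the patch; MinMaxScaler merely stands in for whatever scaler the surrounding code actually constructs): K discrete levels in [1, K] are mapped linearly onto the [-1, 1] range that the HeSBO projection expects.

    import numpy as np
    from sklearn.preprocessing import MinMaxScaler

    max_unique_values_per_param = 5
    scaler = MinMaxScaler(feature_range=(-1, 1))
    scaler.fit(np.array([[1], [max_unique_values_per_param]]))    # per-dimension bounds

    print(scaler.transform(np.array([[1], [3], [5]])).ravel())    # [-1.  0.  1.]
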
hyperparameters = [ - ConfigSpace.UniformIntegerHyperparameter( - name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param - ) + ConfigSpace.UniformIntegerHyperparameter(name=f'dim_{idx}', lower=1, upper=max_unique_values_per_param) for idx in range(num_low_dims) ] @@ -213,12 +177,8 @@ def _construct_low_dim_space( self._q_scaler = q_scaler # Construct low-dimensional parameter search space - config_space = ConfigSpace.ConfigurationSpace( - name=self.orig_parameter_space.name - ) - config_space.random = ( - self._random_state - ) # use same random state as in original parameter space + config_space = ConfigSpace.ConfigurationSpace(name=self.orig_parameter_space.name) + config_space.random = self._random_state # use same random state as in original parameter space config_space.add_hyperparameters(hyperparameters) self._target_config_space = config_space @@ -249,21 +209,17 @@ def _transform(self, configuration: dict) -> dict: for idx in range(len(original_parameters)) ] # Scale parameter values to [0, 1] - original_config_values = self._config_scaler.transform( - [original_config_values] - )[0] + original_config_values = self._config_scaler.transform([original_config_values])[0] original_config = {} for param, norm_value in zip(original_parameters, original_config_values): # Clip value to force it to fall in [0, 1] # NOTE: HeSBO projection ensures that theoretically but due to # floating point ops nuances this is not always guaranteed - value = max( - 0.0, min(1.0, norm_value) - ) # pylint: disable=redefined-loop-name + value = max(0., min(1., norm_value)) # pylint: disable=redefined-loop-name if isinstance(param, ConfigSpace.CategoricalHyperparameter): - index = int(value * len(param.choices)) # truncate integer part + index = int(value * len(param.choices)) # truncate integer part index = max(0, min(len(param.choices) - 1, index)) # NOTE: potential rounding here would be unfair to first & last values orig_value = param.choices[index] @@ -271,20 +227,16 @@ def _transform(self, configuration: dict) -> dict: if param.name in self._special_param_values_dict: value = self._special_param_value_scaler(param, value) - orig_value = param._transform(value) # pylint: disable=protected-access + orig_value = param._transform(value) # pylint: disable=protected-access orig_value = max(param.lower, min(param.upper, orig_value)) else: - raise NotImplementedError( - "Only Categorical, Integer, and Float hyperparameters are currently supported." - ) + raise NotImplementedError("Only Categorical, Integer, and Float hyperparameters are currently supported.") original_config[param.name] = orig_value return original_config - def _special_param_value_scaler( - self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float - ) -> float: + def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float) -> float: """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. Parameters @@ -303,20 +255,17 @@ def _special_param_value_scaler( special_values_list = self._special_param_values_dict[param.name] # Check if input value corresponds to some special value - perc_sum = 0.0 + perc_sum = 0. 
ret: float for special_value, biasing_perc in special_values_list: perc_sum += biasing_perc if input_value < perc_sum: - ret = param._inverse_transform( - special_value - ) # pylint: disable=protected-access + ret = param._inverse_transform(special_value) # pylint: disable=protected-access return ret # Scale input value uniformly to non-special values - ret = param._inverse_transform( # pylint: disable=protected-access - param._transform_scalar((input_value - perc_sum) / (1 - perc_sum)) - ) # pylint: disable=protected-access + ret = param._inverse_transform( # pylint: disable=protected-access + param._transform_scalar((input_value - perc_sum) / (1 - perc_sum))) # pylint: disable=protected-access return ret # pylint: disable=too-complex,too-many-branches @@ -345,79 +294,46 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non hyperparameter = self.orig_parameter_space[param] if not isinstance(hyperparameter, ConfigSpace.UniformIntegerHyperparameter): - raise NotImplementedError( - error_prefix + f"Parameter '{param}' is not supported. " - "Only Integer Hyperparameters are currently supported." - ) + raise NotImplementedError(error_prefix + f"Parameter '{param}' is not supported. " + "Only Integer Hyperparameters are currently supported.") if isinstance(value, int): # User specifies a single special value -- default biasing percentage is used - tuple_list = [ - (value, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) - ] + tuple_list = [(value, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE)] elif isinstance(value, tuple) and [type(v) for v in value] == [int, float]: # User specifies both special value and biasing percentage tuple_list = [value] elif isinstance(value, list) and value: if all(isinstance(t, int) for t in value): # User specifies list of special values - tuple_list = [ - (v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) - for v in value - ] - elif all( - isinstance(t, tuple) and [type(v) for v in t] == [int, float] - for t in value - ): + tuple_list = [(v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value] + elif all(isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value): # User specifies list of tuples; each tuple defines the special value and the biasing percentage tuple_list = value else: - raise ValueError( - error_prefix - + f"Invalid format in value list for parameter '{param}'. " - f"Special value list should contain either integers, or (special value, biasing %) tuples." - ) + raise ValueError(error_prefix + f"Invalid format in value list for parameter '{param}'. " + f"Special value list should contain either integers, or (special value, biasing %) tuples.") else: - raise ValueError( - error_prefix - + f"Invalid format for parameter '{param}'. Dict value should be " - "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples." - ) + raise ValueError(error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " + "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples.") # Are user-specified special values valid? - if not all( - hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list - ): - raise ValueError( - error_prefix - + f"One (or more) special values are outside of parameter '{param}' value domain." 
- ) + if not all(hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list): + raise ValueError(error_prefix + f"One (or more) special values are outside of parameter '{param}' value domain.") # Are user-provided special values unique? if len(set(v for v, _ in tuple_list)) != len(tuple_list): - raise ValueError( - error_prefix - + f"One (or more) special values are defined more than once for parameter '{param}'." - ) + raise ValueError(error_prefix + f"One (or more) special values are defined more than once for parameter '{param}'.") # Are biasing percentages valid? if not all(0 < perc < 1 for _, perc in tuple_list): - raise ValueError( - error_prefix - + f"One (or more) biasing percentages for parameter '{param}' are invalid: " - "i.e., fall outside (0, 1) range." - ) + raise ValueError(error_prefix + f"One (or more) biasing percentages for parameter '{param}' are invalid: " + "i.e., fall outside (0, 1) range.") total_percentage = sum(perc for _, perc in tuple_list) - if total_percentage >= 1.0: - raise ValueError( - error_prefix - + f"Total special values percentage for parameter '{param}' surpass 100%." - ) + if total_percentage >= 1.: + raise ValueError(error_prefix + f"Total special values percentage for parameter '{param}' surpass 100%.") # ... and reasonable? if total_percentage >= 0.5: - warn( - f"Total special values percentage for parameter '{param}' exceeds 50%.", - UserWarning, - ) + warn(f"Total special values percentage for parameter '{param}' exceeds 50%.", UserWarning) sanitized_dict[param] = tuple_list @@ -439,12 +355,9 @@ def _try_generate_approx_inverse_mapping(self) -> None: pinv, ) - warn( - "Trying to register a configuration that was not previously suggested by the optimizer. " - + "This inverse configuration transformation is typically not supported. " - + "However, we will try to register this configuration using an *experimental* method.", - UserWarning, - ) + warn("Trying to register a configuration that was not previously suggested by the optimizer. " + + "This inverse configuration transformation is typically not supported. " + + "However, we will try to register this configuration using an *experimental* method.", UserWarning) orig_space_num_dims = len(list(self.orig_parameter_space.values())) target_space_num_dims = len(list(self.target_parameter_space.values())) @@ -458,7 +371,5 @@ def _try_generate_approx_inverse_mapping(self) -> None: try: self._pinv_matrix = pinv(proj_matrix) except LinAlgError as err: - raise RuntimeError( - f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}" - ) from err + raise RuntimeError(f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}") from err assert self._pinv_matrix.shape == (target_space_num_dims, orig_space_num_dims) diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index 4fec0ed242..d6918f9891 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -27,9 +27,7 @@ FlamlSpace: TypeAlias = Dict[str, flaml.tune.sample.Domain] -def configspace_to_flaml_space( - config_space: ConfigSpace.ConfigurationSpace, -) -> Dict[str, FlamlDomain]: +def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> Dict[str, FlamlDomain]: """Converts a ConfigSpace.ConfigurationSpace to dict. 
Parameters @@ -52,23 +50,13 @@ def configspace_to_flaml_space( def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain: if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter): # FIXME: upper isn't included in the range - return flaml_numeric_type[(type(parameter), parameter.log)]( - parameter.lower, parameter.upper - ) + return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper) elif isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter): - return flaml_numeric_type[(type(parameter), parameter.log)]( - parameter.lower, parameter.upper + 1 - ) + return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper + 1) elif isinstance(parameter, ConfigSpace.CategoricalHyperparameter): if len(np.unique(parameter.probabilities)) > 1: - raise ValueError( - "FLAML doesn't support categorical parameters with non-uniform probabilities." - ) - return flaml.tune.choice(parameter.choices) # TODO: set order? - raise ValueError( - f"Type of parameter {parameter} ({type(parameter)}) not supported." - ) + raise ValueError("FLAML doesn't support categorical parameters with non-uniform probabilities.") + return flaml.tune.choice(parameter.choices) # TODO: set order? + raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.") - return { - param.name: _one_parameter_convert(param) for param in config_space.values() - } + return {param.name: _one_parameter_convert(param) for param in config_space.values()} diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index 99dcbf2b2f..a8ad146205 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -21,7 +21,7 @@ from typing_extensions import TypeAlias -T = TypeVar("T") +T = TypeVar('T') def get_all_submodules(pkg: TypeAlias) -> List[str]: @@ -30,9 +30,7 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]: Useful for dynamically enumerating subclasses. """ submodules = [] - for _, submodule_name, _ in walk_packages( - pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None - ): + for _, submodule_name, _ in walk_packages(pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None): submodules.append(submodule_name) return submodules @@ -43,13 +41,10 @@ def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]: Useful for dynamically enumerating expected test cases. """ return set(cls.__subclasses__()).union( - s for c in cls.__subclasses__() for s in _get_all_subclasses(c) - ) + s for c in cls.__subclasses__() for s in _get_all_subclasses(c)) -def get_all_concrete_subclasses( - cls: Type[T], pkg_name: Optional[str] = None -) -> List[Type[T]]: +def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]: """ Gets a sorted list of all of the concrete subclasses of the given class. Useful for dynamically enumerating expected test cases. 
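
The recursive subclass walk in `_get_all_subclasses` above is a small but handy pattern; a self-contained toy illustration (not part of the patch):

    class Base: ...
    class Mid(Base): ...
    class Leaf(Mid): ...

    def all_subclasses(cls):
        # cls.__subclasses__() returns only *direct* subclasses, so recurse to collect the rest.
        return set(cls.__subclasses__()).union(
            s for c in cls.__subclasses__() for s in all_subclasses(c)
        )

    assert all_subclasses(Base) == {Mid, Leaf}
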
@@ -62,11 +57,5 @@ def get_all_concrete_subclasses( pkg = import_module(pkg_name) submodules = get_all_submodules(pkg) assert submodules - return sorted( - [ - subclass - for subclass in _get_all_subclasses(cls) - if not getattr(subclass, "__abstractmethods__", None) - ], - key=lambda c: (c.__module__, c.__name__), - ) + return sorted([subclass for subclass in _get_all_subclasses(cls) if not getattr(subclass, "__abstractmethods__", None)], + key=lambda c: (c.__module__, c.__name__)) diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index 775afa2455..c7a94dfcc4 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -17,27 +17,24 @@ @pytest.mark.filterwarnings("error:Not Implemented") -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_context_not_implemented_warning( - configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict], -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_context_not_implemented_warning(configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict]) -> None: """ Make sure we raise warnings for the functionality that has not been implemented yet. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=['score'], + **kwargs ) suggestion, _metadata = optimizer.suggest() - scores = pd.DataFrame({"score": [1]}) + scores = pd.DataFrame({'score': [1]}) context = pd.DataFrame([["something"]]) with pytest.raises(UserWarning): diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py index 5efdbb81cf..39231bec5c 100644 --- a/mlos_core/mlos_core/tests/optimizers/conftest.py +++ b/mlos_core/mlos_core/tests/optimizers/conftest.py @@ -18,13 +18,9 @@ def configuration_space() -> CS.ConfigurationSpace: # Start defining a ConfigurationSpace for the Optimizer to search. space = CS.ConfigurationSpace(seed=1234) # Add a continuous input dimension between 0 and 1. - space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) + space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) # Add a categorical hyperparameter with 3 possible values. - space.add_hyperparameter( - CS.CategoricalHyperparameter(name="y", choices=["a", "b", "c"]) - ) + space.add_hyperparameter(CS.CategoricalHyperparameter(name='y', choices=["a", "b", "c"])) # Add a discrete input dimension between 0 and 10. - space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10) - ) + space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='z', lower=0, upper=10)) return space diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index be2e89137d..725d92fbe9 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -23,13 +23,11 @@ def data_frame() -> pd.DataFrame: Toy data frame corresponding to the `configuration_space` hyperparameters. 
The columns are deliberately *not* in alphabetic order. """ - return pd.DataFrame( - { - "y": ["a", "b", "c"], - "x": [0.1, 0.2, 0.3], - "z": [1, 5, 8], - } - ) + return pd.DataFrame({ + 'y': ['a', 'b', 'c'], + 'x': [0.1, 0.2, 0.3], + 'z': [1, 5, 8], + }) @pytest.fixture @@ -38,13 +36,11 @@ def one_hot_data_frame() -> npt.NDArray: One-hot encoding of the `data_frame` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array( - [ - [0.1, 1.0, 0.0, 0.0, 1.0], - [0.2, 0.0, 1.0, 0.0, 5.0], - [0.3, 0.0, 0.0, 1.0, 8.0], - ] - ) + return np.array([ + [0.1, 1.0, 0.0, 0.0, 1.0], + [0.2, 0.0, 1.0, 0.0, 5.0], + [0.3, 0.0, 0.0, 1.0, 8.0], + ]) @pytest.fixture @@ -53,13 +49,11 @@ def series() -> pd.Series: Toy series corresponding to the `configuration_space` hyperparameters. The columns are deliberately *not* in alphabetic order. """ - return pd.Series( - { - "y": "b", - "x": 0.4, - "z": 3, - } - ) + return pd.Series({ + 'y': 'b', + 'x': 0.4, + 'z': 3, + }) @pytest.fixture @@ -68,11 +62,9 @@ def one_hot_series() -> npt.NDArray: One-hot encoding of the `series` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array( - [ - [0.4, 0.0, 1.0, 0.0, 3], - ] - ) + return np.array([ + [0.4, 0.0, 1.0, 0.0, 3], + ]) @pytest.fixture @@ -82,56 +74,48 @@ def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: """ return SmacOptimizer( parameter_space=configuration_space, - optimization_targets=["score"], + optimization_targets=['score'], ) -def test_to_1hot_data_frame( - optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray -) -> None: +def test_to_1hot_data_frame(optimizer: BaseOptimizer, + data_frame: pd.DataFrame, + one_hot_data_frame: npt.NDArray) -> None: """ Toy problem to test one-hot encoding of dataframe. """ assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame) -def test_to_1hot_series( - optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray -) -> None: +def test_to_1hot_series(optimizer: BaseOptimizer, + series: pd.Series, one_hot_series: npt.NDArray) -> None: """ Toy problem to test one-hot encoding of series. """ assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series) -def test_from_1hot_data_frame( - optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray -) -> None: +def test_from_1hot_data_frame(optimizer: BaseOptimizer, + data_frame: pd.DataFrame, + one_hot_data_frame: npt.NDArray) -> None: """ Toy problem to test one-hot decoding of dataframe. """ - assert ( - optimizer._from_1hot(config=one_hot_data_frame).to_dict() - == data_frame.to_dict() - ) + assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict() -def test_from_1hot_series( - optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray -) -> None: +def test_from_1hot_series(optimizer: BaseOptimizer, + series: pd.Series, + one_hot_series: npt.NDArray) -> None: """ Toy problem to test one-hot decoding of series. 
""" one_hot_df = optimizer._from_1hot(config=one_hot_series) - assert ( - one_hot_df.shape[0] == 1 - ), f"Unexpected number of rows ({one_hot_df.shape[0]} != 1)" + assert one_hot_df.shape[0] == 1, f"Unexpected number of rows ({one_hot_df.shape[0]} != 1)" assert one_hot_df.iloc[0].to_dict() == series.to_dict() -def test_round_trip_data_frame( - optimizer: BaseOptimizer, data_frame: pd.DataFrame -) -> None: +def test_round_trip_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFrame) -> None: """ Round-trip test for one-hot-encoding and then decoding a data frame. """ @@ -151,21 +135,17 @@ def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> No assert (series_round_trip.z == series.z).all() -def test_round_trip_reverse_data_frame( - optimizer: BaseOptimizer, one_hot_data_frame: npt.NDArray -) -> None: +def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer, + one_hot_data_frame: npt.NDArray) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ - round_trip = optimizer._to_1hot( - config=optimizer._from_1hot(config=one_hot_data_frame) - ) + round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_data_frame)) assert round_trip == pytest.approx(one_hot_data_frame) -def test_round_trip_reverse_series( - optimizer: BaseOptimizer, one_hot_series: npt.NDArray -) -> None: +def test_round_trip_reverse_series(optimizer: BaseOptimizer, + one_hot_series: npt.NDArray) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index ad9ae51d23..0b9d624a7a 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -20,15 +20,10 @@ _LOG = logging.getLogger(__name__) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_multi_target_opt_wrong_weights( - optimizer_class: Type[BaseOptimizer], kwargs: dict -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None: """ Make sure that the optimizer raises an error if the number of objective weights does not match the number of optimization targets. 
@@ -36,31 +31,23 @@ def test_multi_target_opt_wrong_weights( with pytest.raises(ValueError): optimizer_class( parameter_space=CS.ConfigurationSpace(seed=SEED), - optimization_targets=["main_score", "other_score"], + optimization_targets=['main_score', 'other_score'], objective_weights=[1], - **kwargs, + **kwargs ) -@pytest.mark.parametrize( - ("objective_weights"), - [ - [2, 1], - [0.5, 0.5], - None, - ], -) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_multi_target_opt( - objective_weights: Optional[List[float]], - optimizer_class: Type[BaseOptimizer], - kwargs: dict, -) -> None: +@pytest.mark.parametrize(('objective_weights'), [ + [2, 1], + [0.5, 0.5], + None, +]) +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_multi_target_opt(objective_weights: Optional[List[float]], + optimizer_class: Type[BaseOptimizer], + kwargs: dict) -> None: """ Toy multi-target optimization problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. @@ -69,25 +56,21 @@ def test_multi_target_opt( def objective(point: pd.DataFrame) -> pd.DataFrame: # mix of hyperparameters, optimal is to select the highest possible - return pd.DataFrame( - { - "main_score": point.x + point.y, - "other_score": point.x**2 + point.y**2, - } - ) + return pd.DataFrame({ + "main_score": point.x + point.y, + "other_score": point.x ** 2 + point.y ** 2, + }) input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5) - ) + CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0) - ) + CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) optimizer = optimizer_class( parameter_space=input_space, - optimization_targets=["main_score", "other_score"], + optimization_targets=['main_score', 'other_score'], objective_weights=objective_weights, **kwargs, ) @@ -102,28 +85,27 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {"x", "y"} + assert set(suggestion.columns) == {'x', 'y'} # Check suggestion values are the expected dtype assert isinstance(suggestion.x.iloc[0], np.integer) assert isinstance(suggestion.y.iloc[0], np.floating) # Check that suggestion is in the space test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() - ) + optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
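
A side note on the `suggestion.astype("O")` pattern used just above (and again in later tests): selecting a row from a mixed-dtype DataFrame upcasts integer columns to float, which would then trip integer-hyperparameter validation, whereas casting to object first keeps integers integral. This is presumably why the tests cast before building a `CS.Configuration`; a minimal pandas-only demonstration, not part of the patch:

    import pandas as pd

    df = pd.DataFrame({"x": [3], "y": [0.5]})     # int column + float column

    row = df.iloc[0]                # common dtype is float64, so x becomes 3.0
    print(row["x"], type(row["x"]))

    row = df.astype("O").iloc[0]    # object dtype keeps x as an integer value
    print(row["x"], type(row["x"]))
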
observation = objective(suggestion) assert isinstance(observation, pd.DataFrame) - assert set(observation.columns) == {"main_score", "other_score"} + assert set(observation.columns) == {'main_score', 'other_score'} optimizer.register(configs=suggestion, scores=observation) (best_config, best_score, best_context) = optimizer.get_best_observations() assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {"x", "y"} - assert set(best_score.columns) == {"main_score", "other_score"} + assert set(best_config.columns) == {'x', 'y'} + assert set(best_score.columns) == {'main_score', 'other_score'} assert best_config.shape == (1, 2) assert best_score.shape == (1, 2) @@ -131,7 +113,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {"x", "y"} - assert set(all_scores.columns) == {"main_score", "other_score"} + assert set(all_configs.columns) == {'x', 'y'} + assert set(all_scores.columns) == {'main_score', 'other_score'} assert all_configs.shape == (max_iterations, 2) assert all_scores.shape == (max_iterations, 2) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index c923a4f4bc..5fd28ca1ed 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -32,24 +32,20 @@ _LOG.setLevel(logging.DEBUG) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_create_optimizer_and_suggest( - configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict], -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: """ Test that we can create an optimizer and get a suggestion from it. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=['score'], + **kwargs ) assert optimizer is not None @@ -66,17 +62,11 @@ def test_create_optimizer_and_suggest( optimizer.register_pending(configs=suggestion, metadata=metadata) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_basic_interface_toy_problem( - configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict], -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: """ Toy problem to test the optimizers. """ @@ -87,15 +77,17 @@ def test_basic_interface_toy_problem( if optimizer_class == OptimizerType.SMAC.value: # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100. # To avoid having to train more than 25 model iterations, we set a lower number of max iterations. 
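
The toy objective defined a few lines below, (6x - 2)^2 * sin(12x - 4), is the Forrester test function commonly used in Bayesian-optimization examples. A quick standalone grid check (illustrative only, not part of the patch) shows its minimum on [0, 1] is about -6.02 at x ~ 0.757, which is why the test later asserts that the best observed score drops below -5:

    import numpy as np

    xs = np.linspace(0, 1, 100_001)
    ys = (6 * xs - 2) ** 2 * np.sin(12 * xs - 4)
    print(xs[ys.argmin()], ys.min())   # ~0.7572, ~-6.0207
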
- kwargs["max_trials"] = max_iterations * 2 + kwargs['max_trials'] = max_iterations * 2 def objective(x: pd.Series) -> pd.DataFrame: - return pd.DataFrame({"score": (6 * x - 2) ** 2 * np.sin(12 * x - 4)}) + return pd.DataFrame({"score": (6 * x - 2)**2 * np.sin(12 * x - 4)}) # Emukit doesn't allow specifying a random state, so we set the global seed. np.random.seed(SEED) optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=['score'], + **kwargs ) with pytest.raises(ValueError, match="No observations"): @@ -108,14 +100,12 @@ def objective(x: pd.Series) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {"x", "y", "z"} + assert set(suggestion.columns) == {'x', 'y', 'z'} # check that suggestion is in the space - configuration = CS.Configuration( - optimizer.parameter_space, suggestion.iloc[0].to_dict() - ) + configuration = CS.Configuration(optimizer.parameter_space, suggestion.iloc[0].to_dict()) # Raises an error if outside of configuration space configuration.is_valid_configuration() - observation = objective(suggestion["x"]) + observation = objective(suggestion['x']) assert isinstance(observation, pd.DataFrame) optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -123,8 +113,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {"x", "y", "z"} - assert set(best_score.columns) == {"score"} + assert set(best_config.columns) == {'x', 'y', 'z'} + assert set(best_score.columns) == {'score'} assert best_config.shape == (1, 3) assert best_score.shape == (1, 1) assert best_score.score.iloc[0] < -5 @@ -133,8 +123,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {"x", "y", "z"} - assert set(all_scores.columns) == {"score"} + assert set(all_configs.columns) == {'x', 'y', 'z'} + assert set(all_scores.columns) == {'score'} assert all_configs.shape == (20, 3) assert all_scores.shape == (20, 1) @@ -147,36 +137,27 @@ def objective(x: pd.Series) -> pd.DataFrame: assert pred_all.shape == (20,) -@pytest.mark.parametrize( - ("optimizer_type"), - [ - # Enumerate all supported Optimizers - # *[member for member in OptimizerType], - *list(OptimizerType), - ], -) +@pytest.mark.parametrize(('optimizer_type'), [ + # Enumerate all supported Optimizers + # *[member for member in OptimizerType], + *list(OptimizerType), +]) def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: """ Test that all optimizer types are listed in the ConcreteOptimizer constraints. 
""" - assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member - - -@pytest.mark.parametrize( - ("optimizer_type", "kwargs"), - [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - ], -) -def test_create_optimizer_with_factory_method( - configuration_space: CS.ConfigurationSpace, - optimizer_type: Optional[OptimizerType], - kwargs: Optional[dict], -) -> None: + assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member + + +@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument +]) +def test_create_optimizer_with_factory_method(configuration_space: CS.ConfigurationSpace, + optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: """ Test that we can create an optimizer via a factory. """ @@ -185,13 +166,13 @@ def test_create_optimizer_with_factory_method( if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -207,25 +188,17 @@ def test_create_optimizer_with_factory_method( assert myrepr.startswith(optimizer_type.value.__name__) -@pytest.mark.parametrize( - ("optimizer_type", "kwargs"), - [ - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - ( - OptimizerType.SMAC, - { - # Test with default config. - "use_default_config": True, - # 'n_random_init': 10, - }, - ), - ], -) -def test_optimizer_with_llamatune( - optimizer_type: OptimizerType, kwargs: Optional[dict] -) -> None: +@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + (OptimizerType.SMAC, { + # Test with default config. + 'use_default_config': True, + # 'n_random_init': 10, + }), +]) +def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optional[dict]) -> None: """ Toy problem to test the optimizers with llamatune space adapter. 
""" @@ -242,12 +215,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=1234) # Add two continuous inputs - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="x", lower=0, upper=3) - ) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="y", lower=0, upper=3) - ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=3)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=3)) # Initialize an optimizer that uses LlamaTune space adapter space_adapter_kwargs = { @@ -270,7 +239,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: llamatune_optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=llamatune_optimizer_kwargs, space_adapter_type=SpaceAdapterType.LLAMATUNE, @@ -279,19 +248,16 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Initialize an optimizer that uses the original space optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=optimizer_kwargs, ) assert optimizer is not None assert llamatune_optimizer is not None - assert ( - optimizer.optimizer_parameter_space - != llamatune_optimizer.optimizer_parameter_space - ) + assert optimizer.optimizer_parameter_space != llamatune_optimizer.optimizer_parameter_space llamatune_n_random_init = 0 - opt_n_random_init = int(kwargs.get("n_random_init", 0)) + opt_n_random_init = int(kwargs.get('n_random_init', 0)) if optimizer_type == OptimizerType.SMAC: assert isinstance(optimizer, SmacOptimizer) assert isinstance(llamatune_optimizer, SmacOptimizer) @@ -312,48 +278,37 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # loop for llamatune-optimizer suggestion, metadata = llamatune_optimizer.suggest() - _x, _y = suggestion["x"].iloc[0], suggestion["y"].iloc[0] - assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx( - 3.0, rel=1e-3 - ) # optimizer explores 1-dimensional space + _x, _y = suggestion['x'].iloc[0], suggestion['y'].iloc[0] + assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3., rel=1e-3) # optimizer explores 1-dimensional space observation = objective(suggestion) - llamatune_optimizer.register( - configs=suggestion, scores=observation, metadata=metadata - ) + llamatune_optimizer.register(configs=suggestion, scores=observation, metadata=metadata) # Retrieve best observations best_observation = optimizer.get_best_observations() llamatune_best_observation = llamatune_optimizer.get_best_observations() - for best_config, best_score, best_context in ( - best_observation, - llamatune_best_observation, - ): + for (best_config, best_score, best_context) in (best_observation, llamatune_best_observation): assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {"x", "y"} - assert set(best_score.columns) == {"score"} + assert set(best_config.columns) == {'x', 'y'} + assert set(best_score.columns) == {'score'} (best_config, best_score, _context) = best_observation (llamatune_best_config, llamatune_best_score, _context) = llamatune_best_observation # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that - assert ( - 
best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] - or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] - ) + assert best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] or \ + best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] # Retrieve and check all observations - for all_configs, all_scores, all_contexts in ( - optimizer.get_observations(), - llamatune_optimizer.get_observations(), - ): + for (all_configs, all_scores, all_contexts) in ( + optimizer.get_observations(), llamatune_optimizer.get_observations()): assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {"x", "y"} - assert set(all_scores.columns) == {"score"} + assert set(all_configs.columns) == {'x', 'y'} + assert set(all_scores.columns) == {'score'} assert len(all_configs) == num_iters assert len(all_scores) == num_iters @@ -365,13 +320,12 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses( - BaseOptimizer, pkg_name="mlos_core" # type: ignore[type-abstract] -) +optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses(BaseOptimizer, # type: ignore[type-abstract] + pkg_name='mlos_core') assert optimizer_subclasses -@pytest.mark.parametrize(("optimizer_class"), optimizer_subclasses) +@pytest.mark.parametrize(('optimizer_class'), optimizer_subclasses) def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: """ Test that all optimizer classes are listed in the OptimizerType enum. @@ -380,19 +334,14 @@ def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: assert optimizer_class in optimizer_type_classes -@pytest.mark.parametrize( - ("optimizer_type", "kwargs"), - [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - ], -) -def test_mixed_numerics_type_input_space_types( - optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] -) -> None: +@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument +]) +def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: """ Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
""" @@ -406,23 +355,19 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5) - ) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0) - ) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -436,14 +381,12 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: for _ in range(max_iterations): suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) - assert (suggestion.columns == ["x", "y"]).all() + assert (suggestion.columns == ['x', 'y']).all() # Check suggestion values are the expected dtype - assert isinstance(suggestion["x"].iloc[0], np.integer) - assert isinstance(suggestion["y"].iloc[0], np.floating) + assert isinstance(suggestion['x'].iloc[0], np.integer) + assert isinstance(suggestion['y'].iloc[0], np.floating) # Check that suggestion is in the space - test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() - ) + test_configuration = CS.Configuration(optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py index 13a28d242d..37b8aa3a69 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py @@ -20,33 +20,22 @@ def test_identity_adapter() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) - ) + CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="float_1", lower=0, upper=100) - ) + CS.UniformFloatHyperparameter(name='float_1', lower=0, upper=100)) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) - ) + CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) adapter = IdentityAdapter(orig_parameter_space=input_space) num_configs = 10 - for sampled_config in input_space.sample_configuration( - size=num_configs - ): # pylint: disable=not-an-iterable # (false positive) - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + for sampled_config in input_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # (false positive) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) target_config_df = adapter.inverse_transform(sampled_config_df) assert target_config_df.equals(sampled_config_df) - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) assert target_config == sampled_config orig_config_df = adapter.transform(target_config_df) assert orig_config_df.equals(sampled_config_df) - orig_config = CS.Configuration( - adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() - ) + orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) assert orig_config == sampled_config diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index d0dfcb7691..84dcd4e5c0 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -30,64 +30,34 @@ def construct_parameter_space( for idx in range(n_continuous_params): input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64) - ) + CS.UniformFloatHyperparameter(name=f'cont_{idx}', lower=0, upper=64)) for idx in range(n_integer_params): input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=-1, upper=256) - ) + CS.UniformIntegerHyperparameter(name=f'int_{idx}', lower=-1, upper=256)) for idx in range(n_categorical_params): input_space.add_hyperparameter( - CS.CategoricalHyperparameter( - name=f"str_{idx}", choices=[f"option_{idx}" for idx in range(5)] - ) - ) + CS.CategoricalHyperparameter(name=f'str_{idx}', choices=[f'option_{idx}' for idx in range(5)])) return input_space -@pytest.mark.parametrize( - ("num_target_space_dims", "param_space_kwargs"), - ( - [ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for 
param_space_kwargs in ( - { - "n_continuous_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - { - "n_integer_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - { - "n_categorical_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - # Mix of all three types - { - "n_continuous_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - "n_integer_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - "n_categorical_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - }, - ) - ] - ), -) -def test_num_low_dims( - num_target_space_dims: int, param_space_kwargs: dict -) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) +])) +def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's low-to-high space projection method. """ @@ -96,7 +66,8 @@ def test_num_low_dims( # Number of target parameter space dimensions should be fewer than those of the original space with pytest.raises(ValueError): LlamaTuneAdapter( - orig_parameter_space=input_space, num_low_dims=len(list(input_space.keys())) + orig_parameter_space=input_space, + num_low_dims=len(list(input_space.keys())) ) # Enable only low-dimensional space projections @@ -104,53 +75,35 @@ def test_num_low_dims( orig_parameter_space=input_space, num_low_dims=num_target_space_dims, special_param_values=None, - max_unique_values_per_param=None, + max_unique_values_per_param=None ) sampled_configs = adapter.target_parameter_space.sample_configuration(size=100) - for ( - sampled_config - ) in sampled_configs: # pylint: disable=not-an-iterable # (false positive) + for sampled_config in sampled_configs: # pylint: disable=not-an-iterable # (false positive) # Transform low-dim config to high-dim point/config - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) orig_config_df = adapter.transform(sampled_config_df) # High-dim (i.e., original) config should be valid - orig_config = CS.Configuration( - input_space, values=orig_config_df.iloc[0].to_dict() - ) + orig_config = CS.Configuration(input_space, values=orig_config_df.iloc[0].to_dict()) input_space.check_configuration(orig_config) # Transform high-dim config back to low-dim target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) assert target_config == sampled_config 
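        # A condensed restatement of the round trip asserted above, reusing the
        # `adapter` and `sampled_config` names from this test for illustration only:
        #     low_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys()))
        #     high_df = adapter.transform(low_df)             # low-dim -> original space
        #     back_df = adapter.inverse_transform(high_df)    # original space -> low-dim
        #     back = CS.Configuration(adapter.target_parameter_space, values=back_df.iloc[0].to_dict())
        #     assert back == sampled_config                   # the projection round-trips exactly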
# Try inverse projection (i.e., high-to-low) for previously unseen configs - unseen_sampled_configs = adapter.target_parameter_space.sample_configuration( - size=25 - ) - for ( - unseen_sampled_config - ) in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) - if ( - unseen_sampled_config in sampled_configs - ): # pylint: disable=unsupported-membership-test # (false positive) + unseen_sampled_configs = adapter.target_parameter_space.sample_configuration(size=25) + for unseen_sampled_config in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) + if unseen_sampled_config in sampled_configs: # pylint: disable=unsupported-membership-test # (false positive) continue - unseen_sampled_config_df = pd.DataFrame( - [unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys()) - ) + unseen_sampled_config_df = pd.DataFrame([unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys())) with pytest.raises(ValueError): - _ = adapter.inverse_transform( - unseen_sampled_config_df - ) # pylint: disable=redefined-variable-type + _ = adapter.inverse_transform(unseen_sampled_config_df) # pylint: disable=redefined-variable-type def test_special_parameter_values_validation() -> None: @@ -159,20 +112,15 @@ def test_special_parameter_values_validation() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.CategoricalHyperparameter( - name="str", choices=[f"choice_{idx}" for idx in range(5)] - ) - ) + CS.CategoricalHyperparameter(name='str', choices=[f'choice_{idx}' for idx in range(5)])) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="cont", lower=-1, upper=100) - ) + CS.UniformFloatHyperparameter(name='cont', lower=-1, upper=100)) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int", lower=0, upper=100) - ) + CS.UniformIntegerHyperparameter(name='int', lower=0, upper=100)) # Only UniformIntegerHyperparameters are currently supported with pytest.raises(NotImplementedError): - special_param_values_dict_1 = {"str": "choice_1"} + special_param_values_dict_1 = {'str': 'choice_1'} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -181,7 +129,7 @@ def test_special_parameter_values_validation() -> None: ) with pytest.raises(NotImplementedError): - special_param_values_dict_2 = {"cont": -1} + special_param_values_dict_2 = {'cont': -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -190,8 +138,8 @@ def test_special_parameter_values_validation() -> None: ) # Special value should belong to parameter value domain - with pytest.raises(ValueError, match="value domain"): - special_param_values_dict = {"int": -1} + with pytest.raises(ValueError, match='value domain'): + special_param_values_dict = {'int': -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -201,17 +149,15 @@ def test_special_parameter_values_validation() -> None: # Invalid dicts; ValueError should be thrown invalid_special_param_values_dicts: List[Dict[str, Any]] = [ - {"int-Q": 0}, # parameter does not exist - {"int": {0: 0.2}}, # invalid definition - {"int": 0.2}, # invalid parameter value - { - "int": (0.4, 0) - }, # (biasing %, special value) instead of (special value, biasing %) - {"int": [0, 0]}, # duplicate special values - {"int": []}, # empty list - {"int": [{0: 0.2}]}, - {"int": [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct - {"int": [(0, 0.1), (0, 0.2)]}, # duplicate special values + {'int-Q': 
0}, # parameter does not exist + {'int': {0: 0.2}}, # invalid definition + {'int': 0.2}, # invalid parameter value + {'int': (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) + {'int': [0, 0]}, # duplicate special values + {'int': []}, # empty list + {'int': [{0: 0.2}]}, + {'int': [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct + {'int': [(0, 0.1), (0, 0.2)]}, # duplicate special values ] for spv_dict in invalid_special_param_values_dicts: with pytest.raises(ValueError): @@ -224,13 +170,13 @@ def test_special_parameter_values_validation() -> None: # Biasing percentage of special value(s) are invalid invalid_special_param_values_dicts = [ - {"int": (0, 1.1)}, # >1 probability - {"int": (0, 0)}, # Zero probability - {"int": (0, -0.1)}, # Negative probability - {"int": (0, 20)}, # 2,000% instead of 20% - {"int": [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% - {"int": [(0, 0.4), (1, 0.7)]}, # combined probability >100% - {"int": [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. + {'int': (0, 1.1)}, # >1 probability + {'int': (0, 0)}, # Zero probability + {'int': (0, -0.1)}, # Negative probability + {'int': (0, 20)}, # 2,000% instead of 20% + {'int': [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% + {'int': [(0, 0.4), (1, 0.7)]}, # combined probability >100% + {'int': [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. ] for spv_dict in invalid_special_param_values_dicts: @@ -243,34 +189,24 @@ def test_special_parameter_values_validation() -> None: ) -def gen_random_configs( - adapter: LlamaTuneAdapter, num_configs: int -) -> Iterator[CS.Configuration]: - for sampled_config in adapter.target_parameter_space.sample_configuration( - size=num_configs - ): +def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[CS.Configuration]: + for sampled_config in adapter.target_parameter_space.sample_configuration(size=num_configs): # Transform low-dim config to high-dim config - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) orig_config_df = adapter.transform(sampled_config_df) - orig_config = CS.Configuration( - adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() - ) + orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) yield orig_config -def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex +def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex """ Tests LlamaTune's special parameter values biasing methodology """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) - ) + CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=100) - ) + CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=100)) num_configs = 400 bias_percentage = LlamaTuneAdapter.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE @@ -278,10 +214,10 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp # Single parameter; single special value special_param_value_dicts: List[Dict[str, Any]] = [ - {"int_1": 0}, - {"int_1": (0, bias_percentage)}, - {"int_1": [0]}, - {"int_1": [(0, bias_percentage)]}, 
+ {'int_1': 0}, + {'int_1': (0, bias_percentage)}, + {'int_1': [0]}, + {'int_1': [(0, bias_percentage)]} ] for spv_dict in special_param_value_dicts: @@ -293,18 +229,13 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp ) special_value_occurrences = sum( - 1 - for config in gen_random_configs(adapter, num_configs) - if config["int_1"] == 0 - ) - assert (1 - eps) * int( - num_configs * bias_percentage - ) <= special_value_occurrences + 1 for config in gen_random_configs(adapter, num_configs) if config['int_1'] == 0) + assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences # Single parameter; multiple special values special_param_value_dicts = [ - {"int_1": [0, 1]}, - {"int_1": [(0, bias_percentage), (1, bias_percentage)]}, + {'int_1': [0, 1]}, + {'int_1': [(0, bias_percentage), (1, bias_percentage)]} ] for spv_dict in special_param_value_dicts: @@ -317,22 +248,18 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp special_values_occurrences = {0: 0, 1: 0} for config in gen_random_configs(adapter, num_configs): - if config["int_1"] == 0: + if config['int_1'] == 0: special_values_occurrences[0] += 1 - elif config["int_1"] == 1: + elif config['int_1'] == 1: special_values_occurrences[1] += 1 - assert (1 - eps) * int( - num_configs * bias_percentage - ) <= special_values_occurrences[0] - assert (1 - eps) * int( - num_configs * bias_percentage - ) <= special_values_occurrences[1] + assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[0] + assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[1] # Multiple parameters; multiple special values; different biasing percentage spv_dict = { - "int_1": [(0, bias_percentage), (1, bias_percentage / 2)], - "int_2": [(2, bias_percentage / 2), (100, bias_percentage * 1.5)], + 'int_1': [(0, bias_percentage), (1, bias_percentage / 2)], + 'int_2': [(2, bias_percentage / 2), (100, bias_percentage * 1.5)] } adapter = LlamaTuneAdapter( orig_parameter_space=input_space, @@ -342,32 +269,24 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp ) special_values_instances: Dict[str, Dict[int, int]] = { - "int_1": {0: 0, 1: 0}, - "int_2": {2: 0, 100: 0}, + 'int_1': {0: 0, 1: 0}, + 'int_2': {2: 0, 100: 0}, } for config in gen_random_configs(adapter, num_configs): - if config["int_1"] == 0: - special_values_instances["int_1"][0] += 1 - elif config["int_1"] == 1: - special_values_instances["int_1"][1] += 1 - - if config["int_2"] == 2: - special_values_instances["int_2"][2] += 1 - elif config["int_2"] == 100: - special_values_instances["int_2"][100] += 1 - - assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances[ - "int_1" - ][0] - assert (1 - eps) * int( - num_configs * bias_percentage / 2 - ) <= special_values_instances["int_1"][1] - assert (1 - eps) * int( - num_configs * bias_percentage / 2 - ) <= special_values_instances["int_2"][2] - assert (1 - eps) * int( - num_configs * bias_percentage * 1.5 - ) <= special_values_instances["int_2"][100] + if config['int_1'] == 0: + special_values_instances['int_1'][0] += 1 + elif config['int_1'] == 1: + special_values_instances['int_1'][1] += 1 + + if config['int_2'] == 2: + special_values_instances['int_2'][2] += 1 + elif config['int_2'] == 100: + special_values_instances['int_2'][100] += 1 + + assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances['int_1'][0] + assert (1 - eps) * int(num_configs * 
bias_percentage / 2) <= special_values_instances['int_1'][1] + assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_2'][2] + assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances['int_2'][100] def test_max_unique_values_per_param() -> None: @@ -377,25 +296,17 @@ def test_max_unique_values_per_param() -> None: # Define config space with a mix of different parameter types input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5) - ) + CS.UniformFloatHyperparameter(name='cont_1', lower=0, upper=5)) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="cont_2", lower=1, upper=100) - ) + CS.UniformFloatHyperparameter(name='cont_2', lower=1, upper=100)) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_1", lower=1, upper=10) - ) + CS.UniformIntegerHyperparameter(name='int_1', lower=1, upper=10)) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=2048) - ) + CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=2048)) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) - ) + CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) input_space.add_hyperparameter( - CS.CategoricalHyperparameter( - name="str_2", choices=[f"choice_{idx}" for idx in range(10)] - ) - ) + CS.CategoricalHyperparameter(name='str_2', choices=[f'choice_{idx}' for idx in range(10)])) # Restrict the number of unique parameter values num_configs = 200 @@ -408,9 +319,7 @@ def test_max_unique_values_per_param() -> None: ) # Keep track of unique values generated for each parameter - unique_values_dict: Dict[str, set] = { - param: set() for param in list(input_space.keys()) - } + unique_values_dict: Dict[str, set] = {param: set() for param in list(input_space.keys())} for config in gen_random_configs(adapter, num_configs): for param, value in config.items(): unique_values_dict[param].add(value) @@ -420,48 +329,23 @@ def test_max_unique_values_per_param() -> None: assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize( - ("num_target_space_dims", "param_space_kwargs"), - ( - [ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - { - "n_continuous_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - { - "n_integer_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - { - "n_categorical_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - # Mix of all three types - { - "n_continuous_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - "n_integer_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - "n_categorical_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - }, - ) - ] - ), -) -def test_approx_inverse_mapping( - num_target_space_dims: int, param_space_kwargs: dict -) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_integer_params': int(num_target_space_dims * 
num_orig_space_factor)}, + {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) +])) +def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's approximate high-to-low space projection method, using pseudo-inverse. """ @@ -476,11 +360,9 @@ def test_approx_inverse_mapping( use_approximate_reverse_mapping=False, ) - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.raises(ValueError): - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) _ = adapter.inverse_transform(sampled_config_df) # Enable low-dimensional space projection *and* reverse mapping @@ -493,63 +375,41 @@ def test_approx_inverse_mapping( ) # Warning should be printed the first time - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.warns(UserWarning): - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) adapter.target_parameter_space.check_configuration(target_config) # Test inverse transform with 100 random configs for _ in range(100): - sampled_config = input_space.sample_configuration() # size=1) - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config = input_space.sample_configuration() # size=1) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) adapter.target_parameter_space.check_configuration(target_config) -@pytest.mark.parametrize( - ("num_low_dims", "special_param_values", "max_unique_values_per_param"), - ( - [ - (num_low_dims, special_param_values, max_unique_values_per_param) - for num_low_dims in (8, 16) - for special_param_values in ( - {"int_1": -1, "int_2": -1, "int_3": -1, "int_4": [-1, 0]}, - { - "int_1": (-1, 0.1), - "int_2": -1, - "int_3": (-1, 0.3), - "int_4": [(-1, 0.1), (0, 0.2)], - }, - ) - for max_unique_values_per_param in (50, 250) - ] - ), -) -def test_llamatune_pipeline( - num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int -) -> None: +@pytest.mark.parametrize(('num_low_dims', 'special_param_values', 'max_unique_values_per_param'), 
([ + (num_low_dims, special_param_values, max_unique_values_per_param) + for num_low_dims in (8, 16) + for special_param_values in ( + {'int_1': -1, 'int_2': -1, 'int_3': -1, 'int_4': [-1, 0]}, + {'int_1': (-1, 0.1), 'int_2': -1, 'int_3': (-1, 0.3), 'int_4': [(-1, 0.1), (0, 0.2)]}, + ) + for max_unique_values_per_param in (50, 250) +])) +def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int) -> None: """ Tests LlamaTune space adapter when all components are active. """ # pylint: disable=too-many-locals # Define config space with a mix of different parameter types - input_space = construct_parameter_space( - n_continuous_params=10, n_integer_params=10, n_categorical_params=5 - ) + input_space = construct_parameter_space(n_continuous_params=10, n_integer_params=10, n_categorical_params=5) adapter = LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=num_low_dims, @@ -559,29 +419,23 @@ def test_llamatune_pipeline( special_value_occurrences = { param: {special_value: 0 for special_value, _ in tuples_list} - for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access + for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access } unique_values_dict: Dict[str, Set] = {param: set() for param in input_space.keys()} num_configs = 1000 - for config in adapter.target_parameter_space.sample_configuration( - size=num_configs - ): # pylint: disable=not-an-iterable + for config in adapter.target_parameter_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # Transform low-dim config to high-dim point/config sampled_config_df = pd.DataFrame([config.values()], columns=list(config.keys())) orig_config_df = adapter.transform(sampled_config_df) # High-dim (i.e., original) config should be valid - orig_config = CS.Configuration( - input_space, values=orig_config_df.iloc[0].to_dict() - ) + orig_config = CS.Configuration(input_space, values=orig_config_df.iloc[0].to_dict()) input_space.check_configuration(orig_config) # Transform high-dim config back to low-dim target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) assert target_config == config for param, value in orig_config.items(): @@ -595,66 +449,35 @@ def test_llamatune_pipeline( # Ensure that occurrences of special values do not significantly deviate from expected eps = 0.2 - for ( - param, - tuples_list, - ) in adapter._special_param_values_dict.items(): # pylint: disable=protected-access + for param, tuples_list in adapter._special_param_values_dict.items(): # pylint: disable=protected-access for value, bias_percentage in tuples_list: - assert (1 - eps) * int( - num_configs * bias_percentage - ) <= special_value_occurrences[param][value] + assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[param][value] # Ensure that number of unique values is less than the maximum number allowed for _, unique_values in unique_values_dict.items(): assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize( - ("num_target_space_dims", "param_space_kwargs"), - ( - [ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for 
num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - { - "n_continuous_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - { - "n_integer_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - { - "n_categorical_params": int( - num_target_space_dims * num_orig_space_factor - ) - }, - # Mix of all three types - { - "n_continuous_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - "n_integer_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - "n_categorical_params": int( - num_target_space_dims * num_orig_space_factor / 3 - ), - }, - ) - ] - ), -) -def test_deterministic_behavior_for_same_seed( - num_target_space_dims: int, param_space_kwargs: dict -) -> None: +@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) +])) +def test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_space_kwargs: dict) -> None: """ Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. """ - def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: input_space = construct_parameter_space(**param_space_kwargs, seed=seed) @@ -667,14 +490,8 @@ def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: use_approximate_reverse_mapping=False, ) - sample_configs: List[CS.Configuration] = ( - adapter.target_parameter_space.sample_configuration(size=100) - ) + sample_configs: List[CS.Configuration] = adapter.target_parameter_space.sample_configuration(size=100) return sample_configs - assert generate_target_param_space_configs( - 42 - ) == generate_target_param_space_configs(42) - assert generate_target_param_space_configs( - 1234 - ) != generate_target_param_space_configs(42) + assert generate_target_param_space_configs(42) == generate_target_param_space_configs(42) + assert generate_target_param_space_configs(1234) != generate_target_param_space_configs(42) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index c2edd18b69..5390f97c5f 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -23,51 +23,39 @@ from mlos_core.tests import get_all_concrete_subclasses -@pytest.mark.parametrize( - ("space_adapter_type"), - [ - # Enumerate all supported SpaceAdapters - # *[member for member in SpaceAdapterType], - *list(SpaceAdapterType), - ], -) +@pytest.mark.parametrize(('space_adapter_type'), [ + # Enumerate all supported SpaceAdapters + # *[member for member in SpaceAdapterType], + *list(SpaceAdapterType), +]) def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: """ Test that all optimizer types are 
listed in the ConcreteOptimizer constraints. """ # pylint: disable=no-member - assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] + assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] -@pytest.mark.parametrize( - ("space_adapter_type", "kwargs"), - [ - # Default space adapter - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in SpaceAdapterType], - ], -) -def test_create_space_adapter_with_factory_method( - space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict] -) -> None: +@pytest.mark.parametrize(('space_adapter_type', 'kwargs'), [ + # Default space adapter + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in SpaceAdapterType], +]) +def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict]) -> None: # Start defining a ConfigurationSpace for the Optimizer to search. input_space = CS.ConfigurationSpace(seed=1234) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="x", lower=0, upper=1) - ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="y", lower=0, upper=1) - ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=1)) # Adjust some kwargs for specific space adapters if space_adapter_type is SpaceAdapterType.LLAMATUNE: if kwargs is None: kwargs = {} - kwargs.setdefault("num_low_dims", 1) + kwargs.setdefault('num_low_dims', 1) space_adapter: BaseSpaceAdapter if space_adapter_type is None: @@ -85,25 +73,21 @@ def test_create_space_adapter_with_factory_method( assert space_adapter is not None assert space_adapter.orig_parameter_space is not None myrepr = repr(space_adapter) - assert myrepr.startswith( - space_adapter_type.value.__name__ - ), f"Expected {space_adapter_type.value.__name__} but got {myrepr}" + assert myrepr.startswith(space_adapter_type.value.__name__), \ + f"Expected {space_adapter_type.value.__name__} but got {myrepr}" # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = get_all_concrete_subclasses( - BaseSpaceAdapter, pkg_name="mlos_core" -) # type: ignore[type-abstract] +space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = \ + get_all_concrete_subclasses(BaseSpaceAdapter, pkg_name='mlos_core') # type: ignore[type-abstract] assert space_adapter_subclasses -@pytest.mark.parametrize(("space_adapter_class"), space_adapter_subclasses) +@pytest.mark.parametrize(('space_adapter_class'), space_adapter_subclasses) def test_space_adapter_type_defs(space_adapter_class: Type[BaseSpaceAdapter]) -> None: """ Test that all space adapter classes are listed in the SpaceAdapterType enum. 
""" - space_adapter_type_classes = { - space_adapter_type.value for space_adapter_type in SpaceAdapterType - } + space_adapter_type_classes = {space_adapter_type.value for space_adapter_type in SpaceAdapterType} assert space_adapter_class in space_adapter_type_classes diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index 20666df721..dee9251652 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -41,9 +41,9 @@ def assert_is_uniform(arr: npt.NDArray) -> None: assert np.isclose(frequencies.sum(), 1) _f_chi_sq, f_p_value = scipy.stats.chisquare(frequencies) - assert np.isclose(kurtosis, -1.2, atol=0.1) - assert p_value > 0.3 - assert f_p_value > 0.5 + assert np.isclose(kurtosis, -1.2, atol=.1) + assert p_value > .3 + assert f_p_value > .5 def assert_is_log_uniform(arr: npt.NDArray, base: float = np.e) -> None: @@ -70,20 +70,17 @@ def invalid_conversion_function(*args: Any) -> NoReturn: """ A quick dummy function for the base class to make pylint happy. """ - raise NotImplementedError("subclass must override conversion_function") + raise NotImplementedError('subclass must override conversion_function') class BaseConversion(metaclass=ABCMeta): """ Base class for testing optimizer space conversions. """ - conversion_function: Callable[..., OptimizerSpace] = invalid_conversion_function @abstractmethod - def sample( - self, config_space: OptimizerSpace, n_samples: int = 1 - ) -> OptimizerParam: + def sample(self, config_space: OptimizerSpace, n_samples: int = 1) -> OptimizerParam: """ Sample from the given configuration space. @@ -131,12 +128,8 @@ def test_unsupported_hyperparameter(self) -> None: def test_continuous_bounds(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter("a", lower=100, upper=200) - ) - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter("b", lower=-10, upper=-5) - ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter("a", lower=100, upper=200)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("b", lower=-10, upper=-5)) converted_space = self.conversion_function(input_space) assert self.get_parameter_names(converted_space) == ["a", "b"] @@ -146,12 +139,8 @@ def test_continuous_bounds(self) -> None: def test_uniform_samples(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter("a", lower=1, upper=5) - ) - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter("c", lower=1, upper=20) - ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter("a", lower=1, upper=5)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("c", lower=1, upper=20)) converted_space = self.conversion_function(input_space) np.random.seed(42) @@ -161,16 +150,14 @@ def test_uniform_samples(self) -> None: assert_is_uniform(uniform) # Check that we get both ends of the sampled range returned to us. 
- assert input_space["c"].lower in integer_uniform - assert input_space["c"].upper in integer_uniform + assert input_space['c'].lower in integer_uniform + assert input_space['c'].upper in integer_uniform # integer uniform assert_is_uniform(integer_uniform) def test_uniform_categorical(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.CategoricalHyperparameter("c", choices=["foo", "bar"]) - ) + input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"])) converted_space = self.conversion_function(input_space) points = self.sample(converted_space, n_samples=100) counts = self.categorical_counts(points) @@ -178,13 +165,13 @@ def test_uniform_categorical(self) -> None: assert 35 < counts[1] < 65 def test_weighted_categorical(self) -> None: - raise NotImplementedError("subclass must override") + raise NotImplementedError('subclass must override') def test_log_int_spaces(self) -> None: - raise NotImplementedError("subclass must override") + raise NotImplementedError('subclass must override') def test_log_float_spaces(self) -> None: - raise NotImplementedError("subclass must override") + raise NotImplementedError('subclass must override') class TestFlamlConversion(BaseConversion): @@ -197,12 +184,10 @@ class TestFlamlConversion(BaseConversion): def sample(self, config_space: FlamlSpace, n_samples: int = 1) -> npt.NDArray: # type: ignore[override] assert isinstance(config_space, dict) assert isinstance(next(iter(config_space.values())), flaml.tune.sample.Domain) - ret: npt.NDArray = np.array( - [domain.sample(size=n_samples) for domain in config_space.values()] - ).T + ret: npt.NDArray = np.array([domain.sample(size=n_samples) for domain in config_space.values()]).T return ret - def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] + def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] assert isinstance(config_space, dict) ret: List[str] = list(config_space.keys()) return ret @@ -214,26 +199,16 @@ def categorical_counts(self, points: npt.NDArray) -> npt.NDArray: def test_dimensionality(self) -> None: input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter("a", lower=1, upper=10) - ) - input_space.add_hyperparameter( - CS.CategoricalHyperparameter("b", choices=["bof", "bum"]) - ) - input_space.add_hyperparameter( - CS.CategoricalHyperparameter("c", choices=["foo", "bar"]) - ) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("a", lower=1, upper=10)) + input_space.add_hyperparameter(CS.CategoricalHyperparameter("b", choices=["bof", "bum"])) + input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"])) output_space = configspace_to_flaml_space(input_space) assert len(output_space) == 3 def test_weighted_categorical(self) -> None: np.random.seed(42) input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.CategoricalHyperparameter( - "c", choices=["foo", "bar"], weights=[0.9, 0.1] - ) - ) + input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1])) with pytest.raises(ValueError, match="non-uniform"): configspace_to_flaml_space(input_space) @@ -242,9 +217,7 @@ def test_log_int_spaces(self) -> None: np.random.seed(42) # integer is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True) - ) + 
input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True)) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -262,9 +235,7 @@ def test_log_float_spaces(self) -> None: # continuous is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True) - ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True)) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -274,6 +245,6 @@ def test_log_float_spaces(self) -> None: assert_is_log_uniform(float_log_uniform) -if __name__ == "__main__": +if __name__ == '__main__': # For attaching debugger debugging: pytest.main(["-vv", "-k", "test_log_int_spaces", __file__]) diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index e4533558a9..df0e144535 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -28,9 +28,7 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame: return pd.DataFrame([dict(config)]) -def normalize_config( - config_space: ConfigurationSpace, config: Union[Configuration, dict] -) -> Configuration: +def normalize_config(config_space: ConfigurationSpace, config: Union[Configuration, dict]) -> Configuration: """ Convert a dictionary to a valid ConfigSpace configuration. @@ -49,13 +47,10 @@ def normalize_config( cs_config: Configuration A valid ConfigSpace configuration with inactive parameters removed. """ - cs_config = Configuration( - config_space, values=config, allow_inactive_with_values=True - ) + cs_config = Configuration(config_space, values=config, allow_inactive_with_values=True) return Configuration( - config_space, - values={ + config_space, values={ key: cs_config[key] for key in config_space.get_active_hyperparameters(cs_config) - }, + } ) diff --git a/mlos_core/mlos_core/version.py b/mlos_core/mlos_core/version.py index f946f94aa4..2362de7083 100644 --- a/mlos_core/mlos_core/version.py +++ b/mlos_core/mlos_core/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = "0.5.1" +VERSION = '0.5.1' if __name__ == "__main__": print(VERSION) diff --git a/mlos_core/setup.py b/mlos_core/setup.py index 4a76b78020..fed376d1af 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -21,24 +21,21 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns["VERSION"] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns['VERSION'] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - - version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) + version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: warning("setuptools_scm not found, using version from version.py") except LookupError as e: - warning( - f"setuptools_scm failed to find git version, using version from version.py: {e}" - ) + warning(f"setuptools_scm failed to find git version, using version from version.py: {e}") # A simple routine to read and adjust the README.md for this module into a format @@ -52,59 +49,53 @@ # we return nothing when the file is not available. 
def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, "README.md") + readme_path = os.path.join(pkg_dir, 'README.md') if not os.path.isfile(readme_path): return { - "long_description": "missing", + 'long_description': 'missing', } - jsonc_re = re.compile(r"```jsonc") - link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") - with open(readme_path, mode="r", encoding="utf-8") as readme_fh: + jsonc_re = re.compile(r'```jsonc') + link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') + with open(readme_path, mode='r', encoding='utf-8') as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r"```json", line) for line in lines] + lines = [jsonc_re.sub(r'```json', line) for line in lines] return { - "long_description": "".join(lines), - "long_description_content_type": "text/markdown", + 'long_description': ''.join(lines), + 'long_description_content_type': 'text/markdown', } -extra_requires: Dict[str, List[str]] = ( - { # pylint: disable=consider-using-namedtuple-or-dataclass - "flaml": ["flaml[blendsearch]"], - "smac": ["smac>=2.0.0"], # NOTE: Major refactoring on SMAC starting from v2.0.0 - } -) +extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass + 'flaml': ['flaml[blendsearch]'], + 'smac': ['smac>=2.0.0'], # NOTE: Major refactoring on SMAC starting from v2.0.0 +} # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. 
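# ('full' is the union of every per-backend extra's requirements, and 'full-tests'
#  additionally layers the pytest plugins on top of 'full', as constructed below.)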
-extra_requires["full"] = list(set(chain(*extra_requires.values()))) +extra_requires['full'] = list(set(chain(*extra_requires.values()))) -extra_requires["full-tests"] = extra_requires["full"] + [ - "pytest", - "pytest-forked", - "pytest-xdist", - "pytest-cov", - "pytest-local-badge", +extra_requires['full-tests'] = extra_requires['full'] + [ + 'pytest', + 'pytest-forked', + 'pytest-xdist', + 'pytest-cov', + 'pytest-local-badge', ] setup( version=VERSION, install_requires=[ - "scikit-learn>=1.2", - "joblib>=1.1.1", # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released - "scipy>=1.3.2", - "numpy>=1.24", - "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710 - 'pandas >= 2.2.0;python_version>="3.9"', - 'Bottleneck > 1.3.5;python_version>="3.9"', + 'scikit-learn>=1.2', + 'joblib>=1.1.1', # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released + 'scipy>=1.3.2', + 'numpy>=1.24', 'numpy<2.0.0', # FIXME: https://github.com/numpy/numpy/issues/26710 + 'pandas >= 2.2.0;python_version>="3.9"', 'Bottleneck > 1.3.5;python_version>="3.9"', 'pandas >= 1.0.3;python_version<"3.9"', - "ConfigSpace>=0.7.1", + 'ConfigSpace>=0.7.1', ], extras_require=extra_requires, - **_get_long_desc_from_readme( - "https://github.com/microsoft/MLOS/tree/main/mlos_core" - ), + **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_core"), ) diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py index b7a88957f3..2390554e1e 100644 --- a/mlos_viz/mlos_viz/__init__.py +++ b/mlos_viz/mlos_viz/__init__.py @@ -23,7 +23,7 @@ class MlosVizMethod(Enum): """ DABL = "dabl" - AUTO = DABL # use dabl as the current default + AUTO = DABL # use dabl as the current default def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None: @@ -39,21 +39,17 @@ def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) base.ignore_plotter_warnings() if plotter_method == MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel - mlos_viz.dabl.ignore_plotter_warnings() else: raise NotImplementedError(f"Unhandled method: {plotter_method}") -def plot( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - plotter_method: MlosVizMethod = MlosVizMethod.AUTO, - filter_warnings: bool = True, - **kwargs: Any, -) -> None: +def plot(exp_data: Optional[ExperimentData] = None, *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + plotter_method: MlosVizMethod = MlosVizMethod.AUTO, + filter_warnings: bool = True, + **kwargs: Any) -> None: """ Plots the results of the experiment. 
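Usage note for the signature reflowed above: everything after exp_data is keyword-only (note the bare *), so a typical call looks roughly like the sketch below. The empty dataframe and the "result.score" objective column are placeholders for illustration, not names taken from this patch.

import pandas
from mlos_viz import MlosVizMethod, plot

results_df = pandas.DataFrame()   # placeholder: load real experiment results here
plot(
    results_df=results_df,
    objectives={"result.score": "min"},   # hypothetical objective column -> direction
    plotter_method=MlosVizMethod.DABL,
    filter_warnings=True,
)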
@@ -81,13 +77,10 @@ def plot( (results_df, _obj_cols) = expand_results_data_args(exp_data, results_df, objectives) base.plot_optimizer_trends(exp_data, results_df=results_df, objectives=objectives) - base.plot_top_n_configs( - exp_data, results_df=results_df, objectives=objectives, **kwargs - ) + base.plot_top_n_configs(exp_data, results_df=results_df, objectives=objectives, **kwargs) if MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel - mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives) else: raise NotImplementedError(f"Unhandled method: {plotter_method}") diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 572759f816..15358b0862 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -20,7 +20,7 @@ from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_viz.util import expand_results_data_args -_SEABORN_VERS = version("seaborn") +_SEABORN_VERS = version('seaborn') def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: @@ -30,7 +30,7 @@ def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: Note: this only works with non-positional kwargs (e.g., those after a * arg). """ target_kwargs = {} - for kword in target.__kwdefaults__: # or {} # intentionally omitted for now + for kword in target.__kwdefaults__: # or {} # intentionally omitted for now if kword in kwargs: target_kwargs[kword] = kwargs[kword] return target_kwargs @@ -42,19 +42,14 @@ def ignore_plotter_warnings() -> None: adding them to the warnings filter. """ warnings.filterwarnings("ignore", category=FutureWarning) - if _SEABORN_VERS <= "0.13.1": - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - module="seaborn", # but actually comes from pandas - message="is_categorical_dtype is deprecated and will be removed in a future version.", - ) + if _SEABORN_VERS <= '0.13.1': + warnings.filterwarnings("ignore", category=DeprecationWarning, module="seaborn", # but actually comes from pandas + message="is_categorical_dtype is deprecated and will be removed in a future version.") -def _add_groupby_desc_column( - results_df: pandas.DataFrame, - groupby_columns: Optional[List[str]] = None, -) -> Tuple[pandas.DataFrame, List[str], str]: +def _add_groupby_desc_column(results_df: pandas.DataFrame, + groupby_columns: Optional[List[str]] = None, + ) -> Tuple[pandas.DataFrame, List[str], str]: """ Adds a group descriptor column to the results_df. 
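For readers comparing the two formattings of the helper above: the operation is just a string concatenation of the group-id columns into one descriptor column. A self-contained sketch of the effect on toy data (values made up for illustration):

import pandas

results_df = pandas.DataFrame({
    "tunable_config_trial_group_id": [1, 1, 2],
    "tunable_config_id": [7, 7, 9],
})
groupby_columns = ["tunable_config_trial_group_id", "tunable_config_id"]
groupby_column = ",".join(groupby_columns)
results_df[groupby_column] = results_df[groupby_columns].astype(str).apply(
    lambda x: ",".join(x), axis=1)
# The new column holds "1,7", "1,7", "2,9"; the helper also appends its name to groupby_columns.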
@@ -72,19 +67,17 @@ def _add_groupby_desc_column( if groupby_columns is None: groupby_columns = ["tunable_config_trial_group_id", "tunable_config_id"] groupby_column = ",".join(groupby_columns) - results_df[groupby_column] = ( - results_df[groupby_columns].astype(str).apply(lambda x: ",".join(x), axis=1) - ) # pylint: disable=unnecessary-lambda + results_df[groupby_column] = results_df[groupby_columns].astype(str).apply( + lambda x: ",".join(x), axis=1) # pylint: disable=unnecessary-lambda groupby_columns.append(groupby_column) return (results_df, groupby_columns, groupby_column) -def augment_results_df_with_config_trial_group_stats( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - requested_result_cols: Optional[Iterable[str]] = None, -) -> pandas.DataFrame: +def augment_results_df_with_config_trial_group_stats(exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + requested_result_cols: Optional[Iterable[str]] = None, + ) -> pandas.DataFrame: # pylint: disable=too-complex """ Add a number of useful statistical measure columns to the results dataframe. @@ -141,50 +134,30 @@ def augment_results_df_with_config_trial_group_stats( raise ValueError(f"Not enough data: {len(results_groups)}") if requested_result_cols is None: - result_cols = set( - col - for col in results_df.columns - if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) - ) + result_cols = set(col for col in results_df.columns if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX)) else: - result_cols = set( - col - for col in requested_result_cols - if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) - and col in results_df.columns - ) - result_cols.update( - set( - ExperimentData.RESULT_COLUMN_PREFIX + col - for col in requested_result_cols - if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns - ) - ) + result_cols = set(col for col in requested_result_cols + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns) + result_cols.update(set(ExperimentData.RESULT_COLUMN_PREFIX + col for col in requested_result_cols + if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns)) def compute_zscore_for_group_agg( - results_groups_perf: "SeriesGroupBy", - stats_df: pandas.DataFrame, - result_col: str, - agg: Union[Literal["mean"], Literal["var"], Literal["std"]], + results_groups_perf: "SeriesGroupBy", + stats_df: pandas.DataFrame, + result_col: str, + agg: Union[Literal["mean"], Literal["var"], Literal["std"]] ) -> None: - results_groups_perf_aggs = results_groups_perf.agg( - agg - ) # TODO: avoid recalculating? + results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe. 
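        # That is, for each config group g with aggregate a_g (agg == "var" at the call site):
        #     zscore_g = (a_g - mean_over_groups(a)) / stddev_over_groups(a)
        # which is what the three stats_df columns below compute before the intermediate
        # mean/stddev columns are dropped again.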
stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean() stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std() - stats_df[result_col + f".{agg}_zscore"] = ( - stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"] - ) / stats_df[result_col + f".{agg}_stddev"] - stats_df.drop( - columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], - inplace=True, - ) + stats_df[result_col + f".{agg}_zscore"] = \ + (stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"]) \ + / stats_df[result_col + f".{agg}_stddev"] + stats_df.drop(columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True) augmented_results_df = results_df - augmented_results_df["tunable_config_trial_group_size"] = results_groups[ - "trial_id" - ].transform("count") + augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform("count") for result_col in result_cols: if not result_col.startswith(ExperimentData.RESULT_COLUMN_PREFIX): continue @@ -197,31 +170,26 @@ def compute_zscore_for_group_agg( continue results_groups_perf = results_groups[result_col] stats_df = pandas.DataFrame() - stats_df[result_col + ".mean"] = results_groups_perf.transform( - "mean", numeric_only=True - ) + stats_df[result_col + ".mean"] = results_groups_perf.transform("mean", numeric_only=True) stats_df[result_col + ".var"] = results_groups_perf.transform("var") - stats_df[result_col + ".stddev"] = stats_df[result_col + ".var"].apply( - lambda x: x**0.5 - ) + stats_df[result_col + ".stddev"] = stats_df[result_col + ".var"].apply(lambda x: x**0.5) compute_zscore_for_group_agg(results_groups_perf, stats_df, result_col, "var") quantiles = [0.50, 0.75, 0.90, 0.95, 0.99] - for quantile in quantiles: # TODO: can we do this in one pass? + for quantile in quantiles: # TODO: can we do this in one pass? quantile_col = f"{result_col}.p{int(quantile * 100)}" stats_df[quantile_col] = results_groups_perf.transform("quantile", quantile) augmented_results_df = pandas.concat([augmented_results_df, stats_df], axis=1) return augmented_results_df -def limit_top_n_configs( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - top_n_configs: int = 10, - method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", -) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: +def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + top_n_configs: int = 10, + method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", + ) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: # pylint: disable=too-many-locals """ Utility function to process the results and determine the best performing @@ -251,9 +219,7 @@ def limit_top_n_configs( raise ValueError(f"Invalid method: {method}") # Prepare the orderby columns. - (results_df, objs_cols) = expand_results_data_args( - exp_data, results_df=results_df, objectives=objectives - ) + (results_df, objs_cols) = expand_results_data_args(exp_data, results_df=results_df, objectives=objectives) assert isinstance(results_df, pandas.DataFrame) # Augment the results dataframe with some useful stats. @@ -266,19 +232,13 @@ def limit_top_n_configs( # results_df is not None and is in fact a DataFrame, so we periodically assert # it in this func for now. 
assert results_df is not None - orderby_cols: Dict[str, bool] = { - obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items() - } + orderby_cols: Dict[str, bool] = {obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items()} config_id_col = "tunable_config_id" - group_id_col = "tunable_config_trial_group_id" # first trial_id per config group + group_id_col = "tunable_config_trial_group_id" # first trial_id per config group trial_id_col = "trial_id" - default_config_id = ( - results_df[trial_id_col].min() - if exp_data is None - else exp_data.default_tunable_config_id - ) + default_config_id = results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id assert default_config_id is not None, "Failed to determine default config id." # Filter out configs whose variance is too large. @@ -290,20 +250,16 @@ def limit_top_n_configs( singletons_mask = results_df["tunable_config_trial_group_size"] == 1 else: singletons_mask = results_df["tunable_config_trial_group_size"] > 1 - results_df = results_df.loc[ - ( - (results_df[f"{obj_col}.var_zscore"].abs() < 2) - | (singletons_mask) - | (results_df[config_id_col] == default_config_id) - ) - ] + results_df = results_df.loc[( + (results_df[f"{obj_col}.var_zscore"].abs() < 2) + | (singletons_mask) + | (results_df[config_id_col] == default_config_id) + )] assert results_df is not None # Also, filter results that are worse than the default. - default_config_results_df = results_df.loc[ - results_df[config_id_col] == default_config_id - ] - for orderby_col, ascending in orderby_cols.items(): + default_config_results_df = results_df.loc[results_df[config_id_col] == default_config_id] + for (orderby_col, ascending) in orderby_cols.items(): default_vals = default_config_results_df[orderby_col].unique() assert len(default_vals) == 1 default_val = default_vals[0] @@ -315,38 +271,29 @@ def limit_top_n_configs( # Now regroup and filter to the top-N configs by their group performance dimensions. assert results_df is not None - group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[ - orderby_cols.keys() - ] - top_n_config_ids: List[int] = ( - group_results_df.sort_values( - by=list(orderby_cols.keys()), ascending=list(orderby_cols.values()) - ) - .head(top_n_configs) - .index.tolist() - ) + group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[orderby_cols.keys()] + top_n_config_ids: List[int] = group_results_df.sort_values( + by=list(orderby_cols.keys()), ascending=list(orderby_cols.values())).head(top_n_configs).index.tolist() # Remove the default config if it's included. We'll add it back later. if default_config_id in top_n_config_ids: top_n_config_ids.remove(default_config_id) # Get just the top-n config results. # Sort by the group ids. - top_n_config_results_df = results_df.loc[ - (results_df[config_id_col].isin(top_n_config_ids)) - ].sort_values([group_id_col, config_id_col, trial_id_col]) + top_n_config_results_df = results_df.loc[( + results_df[config_id_col].isin(top_n_config_ids) + )].sort_values([group_id_col, config_id_col, trial_id_col]) # Place the default config at the top of the list. 
top_n_config_ids.insert(0, default_config_id) - top_n_config_results_df = pandas.concat( - [default_config_results_df, top_n_config_results_df], axis=0 - ) + top_n_config_results_df = pandas.concat([default_config_results_df, top_n_config_results_df], axis=0) return (top_n_config_results_df, top_n_config_ids, orderby_cols) def plot_optimizer_trends( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, ) -> None: """ Plots the optimizer trends for the Experiment. @@ -365,16 +312,12 @@ def plot_optimizer_trends( (results_df, obj_cols) = expand_results_data_args(exp_data, results_df, objectives) (results_df, groupby_columns, groupby_column) = _add_groupby_desc_column(results_df) - for objective_column, ascending in obj_cols.items(): + for (objective_column, ascending) in obj_cols.items(): incumbent_column = objective_column + ".incumbent" # Determine the mean of each config trial group to match the box plots. - group_results_df = ( - results_df.groupby(groupby_columns)[objective_column] - .mean() - .reset_index() - .sort_values(groupby_columns) - ) + group_results_df = results_df.groupby(groupby_columns)[objective_column].mean()\ + .reset_index().sort_values(groupby_columns) # # Note: technically the optimizer (usually) uses the *first* result for a # given config trial group before moving on to a new config (x-axis), so @@ -388,13 +331,9 @@ def plot_optimizer_trends( # Calculate the incumbent (best seen so far) if ascending: - group_results_df[incumbent_column] = group_results_df[ - objective_column - ].cummin() + group_results_df[incumbent_column] = group_results_df[objective_column].cummin() else: - group_results_df[incumbent_column] = group_results_df[ - objective_column - ].cummax() + group_results_df[incumbent_column] = group_results_df[objective_column].cummax() (_fig, axis) = plt.subplots(figsize=(15, 5)) @@ -416,29 +355,24 @@ def plot_optimizer_trends( ax=axis, ) - plt.yscale("log") + plt.yscale('log') plt.ylabel(objective_column.replace(ExperimentData.RESULT_COLUMN_PREFIX, "")) plt.xlabel("Config Trial Group ID, Config ID") plt.xticks(rotation=90, fontsize=8) - plt.title( - "Optimizer Trends for Experiment: " + exp_data.experiment_id - if exp_data is not None - else "" - ) + plt.title("Optimizer Trends for Experiment: " + exp_data.experiment_id if exp_data is not None else "") plt.grid() plt.show() # type: ignore[no-untyped-call] -def plot_top_n_configs( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - with_scatter_plot: bool = False, - **kwargs: Any, -) -> None: +def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + with_scatter_plot: bool = False, + **kwargs: Any, + ) -> None: # pylint: disable=too-many-locals """ Plots the top-N configs along with the default config for the given ExperimentData. 
@@ -466,16 +400,12 @@ def plot_top_n_configs( top_n_config_args["results_df"] = results_df if "objectives" not in top_n_config_args: top_n_config_args["objectives"] = objectives - (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs( - exp_data=exp_data, **top_n_config_args - ) + (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs(exp_data=exp_data, **top_n_config_args) - (top_n_config_results_df, _groupby_columns, groupby_column) = ( - _add_groupby_desc_column(top_n_config_results_df) - ) + (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column(top_n_config_results_df) top_n = len(top_n_config_results_df[groupby_column].unique()) - 1 - for orderby_col, ascending in orderby_cols.items(): + for (orderby_col, ascending) in orderby_cols.items(): opt_tgt = orderby_col.replace(ExperimentData.RESULT_COLUMN_PREFIX, "") (_fig, axis) = plt.subplots() sns.violinplot( @@ -495,12 +425,12 @@ def plot_top_n_configs( plt.grid() (xticks, xlabels) = plt.xticks() # default should be in the first position based on top_n_configs() return - xlabels[0] = "default" # type: ignore[call-overload] - plt.xticks(xticks, xlabels) # type: ignore[arg-type] + xlabels[0] = "default" # type: ignore[call-overload] + plt.xticks(xticks, xlabels) # type: ignore[arg-type] plt.xlabel("Config Trial Group, Config ID") plt.xticks(rotation=90) plt.ylabel(opt_tgt) - plt.yscale("log") + plt.yscale('log') extra_title = "(lower is better)" if ascending else "(lower is better)" plt.title(f"Top {top_n} configs {opt_tgt} {extra_title}") plt.show() # type: ignore[no-untyped-call] diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 9d7f673612..504486a58c 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -15,12 +15,10 @@ from mlos_viz.util import expand_results_data_args -def plot( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, -) -> None: +def plot(exp_data: Optional[ExperimentData] = None, *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + ) -> None: """ Plots the Experiment results data using dabl. 
@@ -46,51 +44,17 @@ def ignore_plotter_warnings() -> None: """ # pylint: disable=import-outside-toplevel warnings.filterwarnings("ignore", category=FutureWarning) - warnings.filterwarnings( - "ignore", module="dabl", category=UserWarning, message="Could not infer format" - ) - warnings.filterwarnings( - "ignore", - module="dabl", - category=UserWarning, - message="(Dropped|Discarding) .* outliers", - ) - warnings.filterwarnings( - "ignore", - module="dabl", - category=UserWarning, - message="Not plotting highly correlated", - ) - warnings.filterwarnings( - "ignore", - module="dabl", - category=UserWarning, - message="Missing values in target_col have been removed for regression", - ) + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Could not infer format") + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers") + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated") + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, + message="Missing values in target_col have been removed for regression") from sklearn.exceptions import UndefinedMetricWarning - - warnings.filterwarnings( - "ignore", - module="sklearn", - category=UndefinedMetricWarning, - message="Recall is ill-defined", - ) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message="is_categorical_dtype is deprecated and will be removed in a future version.", - ) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - module="sklearn", - message="is_sparse is deprecated and will be removed in a future version.", - ) + warnings.filterwarnings("ignore", module="sklearn", category=UndefinedMetricWarning, message="Recall is ill-defined") + warnings.filterwarnings("ignore", category=DeprecationWarning, + message="is_categorical_dtype is deprecated and will be removed in a future version.") + warnings.filterwarnings("ignore", category=DeprecationWarning, module="sklearn", + message="is_sparse is deprecated and will be removed in a future version.") from matplotlib._api.deprecation import MatplotlibDeprecationWarning - - warnings.filterwarnings( - "ignore", - category=MatplotlibDeprecationWarning, - module="dabl", - message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed", - ) + warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning, module="dabl", + message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed") diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py index e5528f9875..06ac4a7664 100644 --- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py +++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py @@ -30,5 +30,5 @@ def test_plot(mock_show: Mock, mock_boxplot: Mock, exp_data: ExperimentData) -> warnings.simplefilter("error") random.seed(42) plot(exp_data, filter_warnings=True) - assert mock_show.call_count >= 2 # from the two base plots and anything dabl did - assert mock_boxplot.call_count >= 1 # from anything dabl did + assert mock_show.call_count >= 2 # from the two base plots and anything dabl did + assert mock_boxplot.call_count >= 1 # from anything dabl did diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index 4e081193bc..744fe28648 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -41,35 +41,24 @@ def expand_results_data_args( # Prepare the orderby columns. 
if results_df is None: if exp_data is None: - raise ValueError( - "Must provide either exp_data or both results_df and objectives." - ) + raise ValueError("Must provide either exp_data or both results_df and objectives.") results_df = exp_data.results_df if objectives is None: if exp_data is None: - raise ValueError( - "Must provide either exp_data or both results_df and objectives." - ) + raise ValueError("Must provide either exp_data or both results_df and objectives.") objectives = exp_data.objectives objs_cols: Dict[str, bool] = {} - for opt_tgt, opt_dir in objectives.items(): + for (opt_tgt, opt_dir) in objectives.items(): if opt_dir not in ["min", "max"]: - raise ValueError( - f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}" - ) + raise ValueError(f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}") ascending = opt_dir == "min" - if ( - opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) - and opt_tgt in results_df.columns - ): + if opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and opt_tgt in results_df.columns: objs_cols[opt_tgt] = ascending elif ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt in results_df.columns: objs_cols[ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt] = ascending else: - raise UserWarning( - f"{opt_tgt} is not a result column for experiment {exp_data}" - ) + raise UserWarning(f"{opt_tgt} is not a result column for experiment {exp_data}") # Note: these copies are important to avoid issues with downstream consumers. # It is more efficient to copy the dataframe than to go back to the original data source. # TODO: However, it should be possible to later fixup the downstream consumers diff --git a/mlos_viz/mlos_viz/version.py b/mlos_viz/mlos_viz/version.py index d418ae43c7..607c7cc014 100644 --- a/mlos_viz/mlos_viz/version.py +++ b/mlos_viz/mlos_viz/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = "0.5.1" +VERSION = '0.5.1' if __name__ == "__main__": print(VERSION) diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index d8f6595813..98d12598e1 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -21,24 +21,21 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns["VERSION"] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns['VERSION'] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - - version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) + version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: warning("setuptools_scm not found, using version from version.py") except LookupError as e: - warning( - f"setuptools_scm failed to find git version, using version from version.py: {e}" - ) + warning(f"setuptools_scm failed to find git version, using version from version.py: {e}") # A simple routine to read and adjust the README.md for this module into a format @@ -50,22 +47,22 @@ # be duplicated for now. 
def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, "README.md") + readme_path = os.path.join(pkg_dir, 'README.md') if not os.path.isfile(readme_path): return { - "long_description": "missing", + 'long_description': 'missing', } - jsonc_re = re.compile(r"```jsonc") - link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") - with open(readme_path, mode="r", encoding="utf-8") as readme_fh: + jsonc_re = re.compile(r'```jsonc') + link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') + with open(readme_path, mode='r', encoding='utf-8') as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r"```json", line) for line in lines] + lines = [jsonc_re.sub(r'```json', line) for line in lines] return { - "long_description": "".join(lines), - "long_description_content_type": "text/markdown", + 'long_description': ''.join(lines), + 'long_description_content_type': 'text/markdown', } @@ -73,25 +70,23 @@ def _get_long_desc_from_readme(base_url: str) -> dict: # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires["full"] = list(set(chain(*extra_requires.values()))) +extra_requires['full'] = list(set(chain(*extra_requires.values()))) -extra_requires["full-tests"] = extra_requires["full"] + [ - "pytest", - "pytest-forked", - "pytest-xdist", - "pytest-cov", - "pytest-local-badge", +extra_requires['full-tests'] = extra_requires['full'] + [ + 'pytest', + 'pytest-forked', + 'pytest-xdist', + 'pytest-cov', + 'pytest-local-badge', ] setup( version=VERSION, install_requires=[ - "mlos-bench==" + VERSION, - "dabl>=0.2.6", - "matplotlib<3.9", # FIXME: https://github.com/dabl/dabl/pull/341 + 'mlos-bench==' + VERSION, + 'dabl>=0.2.6', + 'matplotlib<3.9', # FIXME: https://github.com/dabl/dabl/pull/341 ], extras_require=extra_requires, - **_get_long_desc_from_readme( - "https://github.com/microsoft/MLOS/tree/main/mlos_viz" - ), + **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_viz'), ) From 08baf0fd0e7f0e61b994cb56cc43fe70294793da Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 16:16:44 +0000 Subject: [PATCH 10/54] Revert "adjust line length for other checkers" This reverts commit ebd406eb4338dddf838967223957412f36b6bec0. --- .editorconfig | 4 ---- .pylintrc | 2 +- setup.cfg | 4 ++-- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.editorconfig b/.editorconfig index c2b6ed65db..e984d47595 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,10 +12,6 @@ charset = utf-8 # Note: this is not currently supported by all editors or their editorconfig plugins. max_line_length = 132 -# See Also: black configuration in pyproject.toml -[*.py] -max_line_length = 88 - # Makefiles need tab indentation [{Makefile,*.mk}] indent_style = tab diff --git a/.pylintrc b/.pylintrc index 6b873c5d60..e686070503 100644 --- a/.pylintrc +++ b/.pylintrc @@ -35,7 +35,7 @@ load-plugins= [FORMAT] # Maximum number of characters on a single line. 
-max-line-length=88
+max-line-length=132
 
 [MESSAGE CONTROL]
 disable=
diff --git a/setup.cfg b/setup.cfg
index b1cf391742..9d09b3356d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,7 +8,7 @@ count = True
 ignore = E203,W503,W504
 format = pylint
 # See Also: .editorconfig, .pylintrc
-max-line-length = 88
+max-line-length = 132
 show-source = True
 statistics = True
@@ -26,7 +26,7 @@ match = .+(?
Date: Mon, 8 Jul 2024 16:17:54 +0000
Subject: [PATCH 11/54] make black use a longer line length for now

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 65f1e5a02c..26409fa408 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [tool.black]
-line-length = 88
+line-length = 132
 target-version = ["py38", "py39", "py310", "py311", "py312"]
 include = '\.pyi?$'

From 857d694d2a9d7756cc80412f10dca5990531f19b Mon Sep 17 00:00:00 2001
From: Brian Kroth
Date: Mon, 8 Jul 2024 16:29:51 +0000
Subject: [PATCH 12/54] decrease line length to a longer line than black defaults to but still within the pep8 guidelines

---
 .editorconfig  | 3 +++
 .pylintrc      | 2 +-
 pyproject.toml | 2 +-
 setup.cfg      | 2 +-
 4 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/.editorconfig b/.editorconfig
index e984d47595..b31e722644 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -12,6 +12,9 @@ charset = utf-8
 # Note: this is not currently supported by all editors or their editorconfig plugins.
 max_line_length = 132
 
+[{*.py,*.pyi}]
+max_line_length = 99
+
 # Makefiles need tab indentation
 [{Makefile,*.mk}]
 indent_style = tab
diff --git a/.pylintrc b/.pylintrc
index e686070503..c6c512ecb7 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -35,7 +35,7 @@ load-plugins=
 
 [FORMAT]
 # Maximum number of characters on a single line.
-max-line-length=132
+max-line-length=99
 
 [MESSAGE CONTROL]
 disable=
diff --git a/pyproject.toml b/pyproject.toml
index 26409fa408..6673321b6c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [tool.black]
-line-length = 132
+line-length = 99
 target-version = ["py38", "py39", "py310", "py311", "py312"]
 include = '\.pyi?$'
diff --git a/setup.cfg b/setup.cfg
index 9d09b3356d..1f3e7a39d5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -26,7 +26,7 @@ match = .+(?
Date: Mon, 8 Jul 2024 16:32:10 +0000 Subject: [PATCH 13/54] reformat with black at line length 99 --- .../fio/scripts/local/process_fio_results.py | 24 +- .../scripts/local/generate_redis_config.py | 12 +- .../scripts/local/process_redis_results.py | 19 +- .../boot/scripts/local/create_new_grub_cfg.py | 10 +- .../scripts/local/generate_grub_config.py | 10 +- .../local/generate_kernel_config_script.py | 5 +- .../mlos_bench/config/schemas/__init__.py | 4 +- .../config/schemas/config_schemas.py | 21 +- mlos_bench/mlos_bench/dict_templater.py | 14 +- .../mlos_bench/environments/__init__.py | 15 +- .../environments/base_environment.py | 101 +++-- .../mlos_bench/environments/composite_env.py | 52 ++- .../mlos_bench/environments/local/__init__.py | 4 +- .../environments/local/local_env.py | 107 +++-- .../environments/local/local_fileshare_env.py | 67 +-- .../mlos_bench/environments/mock_env.py | 39 +- .../environments/remote/__init__.py | 12 +- .../environments/remote/host_env.py | 29 +- .../environments/remote/network_env.py | 33 +- .../mlos_bench/environments/remote/os_env.py | 36 +- .../environments/remote/remote_env.py | 38 +- .../environments/remote/saas_env.py | 42 +- .../mlos_bench/environments/script_env.py | 40 +- mlos_bench/mlos_bench/event_loop_context.py | 2 +- mlos_bench/mlos_bench/launcher.py | 251 +++++++---- mlos_bench/mlos_bench/optimizers/__init__.py | 8 +- .../mlos_bench/optimizers/base_optimizer.py | 81 ++-- .../optimizers/convert_configspace.py | 114 ++--- .../optimizers/grid_search_optimizer.py | 76 ++-- .../optimizers/mlos_core_optimizer.py | 93 +++-- .../mlos_bench/optimizers/mock_optimizer.py | 26 +- .../optimizers/one_shot_optimizer.py | 12 +- .../optimizers/track_best_optimizer.py | 26 +- mlos_bench/mlos_bench/os_environ.py | 11 +- mlos_bench/mlos_bench/run.py | 5 +- mlos_bench/mlos_bench/schedulers/__init__.py | 4 +- .../mlos_bench/schedulers/base_scheduler.py | 92 ++-- .../mlos_bench/schedulers/sync_scheduler.py | 4 +- mlos_bench/mlos_bench/services/__init__.py | 6 +- .../mlos_bench/services/base_fileshare.py | 43 +- .../mlos_bench/services/base_service.py | 53 ++- .../mlos_bench/services/config_persistence.py | 279 ++++++++----- .../mlos_bench/services/local/__init__.py | 2 +- .../mlos_bench/services/local/local_exec.py | 45 +- .../services/local/temp_dir_context.py | 19 +- .../services/remote/azure/__init__.py | 10 +- .../services/remote/azure/azure_auth.py | 40 +- .../remote/azure/azure_deployment_services.py | 135 +++--- .../services/remote/azure/azure_fileshare.py | 31 +- .../remote/azure/azure_network_services.py | 72 ++-- .../services/remote/azure/azure_saas.py | 111 ++--- .../remote/azure/azure_vm_services.py | 252 ++++++----- .../services/remote/ssh/ssh_fileshare.py | 41 +- .../services/remote/ssh/ssh_host_service.py | 89 ++-- .../services/remote/ssh/ssh_service.py | 126 +++--- .../mlos_bench/services/types/__init__.py | 16 +- .../services/types/config_loader_type.py | 43 +- .../services/types/fileshare_type.py | 8 +- .../services/types/local_exec_type.py | 13 +- .../types/network_provisioner_type.py | 4 +- .../services/types/remote_config_type.py | 3 +- .../services/types/remote_exec_type.py | 5 +- mlos_bench/mlos_bench/storage/__init__.py | 4 +- .../storage/base_experiment_data.py | 19 +- mlos_bench/mlos_bench/storage/base_storage.py | 115 +++-- .../mlos_bench/storage/base_trial_data.py | 17 +- .../storage/base_tunable_config_data.py | 3 +- .../base_tunable_config_trial_group_data.py | 16 +- mlos_bench/mlos_bench/storage/sql/__init__.py | 2 +- 
mlos_bench/mlos_bench/storage/sql/common.py | 225 ++++++---- .../mlos_bench/storage/sql/experiment.py | 250 ++++++----- .../mlos_bench/storage/sql/experiment_data.py | 101 +++-- mlos_bench/mlos_bench/storage/sql/schema.py | 48 ++- mlos_bench/mlos_bench/storage/sql/storage.py | 26 +- mlos_bench/mlos_bench/storage/sql/trial.py | 114 ++--- .../mlos_bench/storage/sql/trial_data.py | 74 ++-- .../storage/sql/tunable_config_data.py | 14 +- .../sql/tunable_config_trial_group_data.py | 41 +- .../mlos_bench/storage/storage_factory.py | 8 +- mlos_bench/mlos_bench/storage/util.py | 18 +- mlos_bench/mlos_bench/tests/__init__.py | 34 +- .../mlos_bench/tests/config/__init__.py | 8 +- .../cli/test_load_cli_config_examples.py | 57 ++- .../mlos_bench/tests/config/conftest.py | 14 +- .../test_load_environment_config_examples.py | 58 ++- .../test_load_global_config_examples.py | 8 +- .../test_load_optimizer_config_examples.py | 8 +- .../tests/config/schemas/__init__.py | 55 ++- .../config/schemas/cli/test_cli_schemas.py | 5 +- .../environments/test_environment_schemas.py | 30 +- .../schemas/globals/test_globals_schemas.py | 1 + .../optimizers/test_optimizer_schemas.py | 58 ++- .../schedulers/test_scheduler_schemas.py | 25 +- .../schemas/services/test_services_schemas.py | 31 +- .../schemas/storage/test_storage_schemas.py | 35 +- .../test_tunable_params_schemas.py | 1 + .../test_tunable_values_schemas.py | 1 + .../test_load_service_config_examples.py | 14 +- .../test_load_storage_config_examples.py | 8 +- mlos_bench/mlos_bench/tests/conftest.py | 16 +- .../mlos_bench/tests/environments/__init__.py | 14 +- .../tests/environments/base_env_test.py | 10 +- .../composite_env_service_test.py | 22 +- .../tests/environments/composite_env_test.py | 143 +++---- .../environments/include_tunables_test.py | 40 +- .../tests/environments/local/__init__.py | 20 +- .../local/composite_local_env_test.py | 19 +- .../local/local_env_stdout_test.py | 88 ++-- .../local/local_env_telemetry_test.py | 145 ++++--- .../environments/local/local_env_test.py | 73 ++-- .../environments/local/local_env_vars_test.py | 57 +-- .../local/local_fileshare_env_test.py | 25 +- .../tests/environments/mock_env_test.py | 64 +-- .../tests/environments/remote/test_ssh_env.py | 18 +- .../tests/event_loop_context_test.py | 57 ++- .../tests/launcher_in_process_test.py | 40 +- .../tests/launcher_parse_args_test.py | 123 +++--- .../mlos_bench/tests/launcher_run_test.py | 93 +++-- .../mlos_bench/tests/optimizers/conftest.py | 48 +-- .../optimizers/grid_search_optimizer_test.py | 105 +++-- .../tests/optimizers/llamatune_opt_test.py | 5 +- .../tests/optimizers/mlos_core_opt_df_test.py | 68 +-- .../optimizers/mlos_core_opt_smac_test.py | 78 ++-- .../tests/optimizers/mock_opt_test.py | 67 +-- .../optimizers/opt_bulk_register_test.py | 101 +++-- .../optimizers/toy_optimization_loop_test.py | 16 +- .../mlos_bench/tests/services/__init__.py | 8 +- .../tests/services/config_persistence_test.py | 29 +- .../tests/services/local/__init__.py | 2 +- .../services/local/local_exec_python_test.py | 9 +- .../tests/services/local/local_exec_test.py | 120 +++--- .../tests/services/local/mock/__init__.py | 2 +- .../local/mock/mock_local_exec_service.py | 23 +- .../mlos_bench/tests/services/mock_service.py | 23 +- .../tests/services/remote/__init__.py | 6 +- .../remote/azure/azure_fileshare_test.py | 142 ++++--- .../azure/azure_network_services_test.py | 83 ++-- .../remote/azure/azure_vm_services_test.py | 202 +++++---- .../tests/services/remote/azure/conftest.py | 95 +++-- 
.../services/remote/mock/mock_auth_service.py | 26 +- .../remote/mock/mock_fileshare_service.py | 25 +- .../remote/mock/mock_network_service.py | 35 +- .../remote/mock/mock_remote_exec_service.py | 26 +- .../services/remote/mock/mock_vm_service.py | 55 ++- .../tests/services/remote/ssh/__init__.py | 14 +- .../tests/services/remote/ssh/fixtures.py | 63 ++- .../services/remote/ssh/test_ssh_fileshare.py | 43 +- .../remote/ssh/test_ssh_host_service.py | 94 +++-- .../services/remote/ssh/test_ssh_service.py | 53 ++- .../mlos_bench/tests/storage/conftest.py | 4 +- .../mlos_bench/tests/storage/exp_data_test.py | 67 +-- .../mlos_bench/tests/storage/exp_load_test.py | 62 +-- .../mlos_bench/tests/storage/sql/fixtures.py | 85 ++-- .../tests/storage/trial_config_test.py | 10 +- .../tests/storage/trial_schedule_test.py | 22 +- .../tests/storage/trial_telemetry_test.py | 41 +- .../tests/storage/tunable_config_data_test.py | 21 +- .../tunable_config_trial_group_data_test.py | 38 +- .../mlos_bench/tests/test_with_alt_tz.py | 6 +- .../tests/tunable_groups_fixtures.py | 38 +- .../mlos_bench/tests/tunables/conftest.py | 47 ++- .../tunables/test_tunable_categoricals.py | 2 +- .../tunables/test_tunables_size_props.py | 27 +- .../tests/tunables/tunable_comparison_test.py | 15 +- .../tests/tunables/tunable_definition_test.py | 98 +++-- .../tunables/tunable_distributions_test.py | 68 ++- .../tunables/tunable_group_indexing_test.py | 4 +- .../tunables/tunable_group_subgroup_test.py | 2 +- .../tunable_to_configspace_distr_test.py | 54 +-- .../tunables/tunable_to_configspace_test.py | 59 ++- .../tests/tunables/tunables_assign_test.py | 26 +- .../tests/tunables/tunables_str_test.py | 76 ++-- mlos_bench/mlos_bench/tunables/__init__.py | 6 +- .../mlos_bench/tunables/covariant_group.py | 18 +- mlos_bench/mlos_bench/tunables/tunable.py | 57 ++- .../mlos_bench/tunables/tunable_groups.py | 58 ++- mlos_bench/mlos_bench/util.py | 42 +- mlos_bench/mlos_bench/version.py | 2 +- mlos_bench/setup.py | 81 ++-- mlos_core/mlos_core/optimizers/__init__.py | 32 +- .../bayesian_optimizers/__init__.py | 4 +- .../bayesian_optimizers/bayesian_optimizer.py | 14 +- .../bayesian_optimizers/smac_optimizer.py | 134 +++--- .../mlos_core/optimizers/flaml_optimizer.py | 57 ++- mlos_core/mlos_core/optimizers/optimizer.py | 117 ++++-- .../mlos_core/optimizers/random_optimizer.py | 30 +- .../mlos_core/spaces/adapters/__init__.py | 19 +- .../mlos_core/spaces/adapters/adapter.py | 6 +- .../mlos_core/spaces/adapters/llamatune.py | 164 +++++--- .../mlos_core/spaces/converters/flaml.py | 18 +- mlos_core/mlos_core/tests/__init__.py | 19 +- .../optimizers/bayesian_optimizers_test.py | 23 +- .../mlos_core/tests/optimizers/conftest.py | 6 +- .../tests/optimizers/one_hot_test.py | 77 ++-- .../optimizers/optimizer_multiobj_test.py | 78 ++-- .../tests/optimizers/optimizer_test.py | 213 ++++++---- .../spaces/adapters/identity_adapter_test.py | 25 +- .../tests/spaces/adapters/llamatune_test.py | 393 +++++++++++------- .../adapters/space_adapter_factory_test.py | 56 ++- .../mlos_core/tests/spaces/spaces_test.py | 39 +- mlos_core/mlos_core/util.py | 10 +- mlos_core/mlos_core/version.py | 2 +- mlos_core/setup.py | 57 +-- mlos_viz/mlos_viz/__init__.py | 19 +- mlos_viz/mlos_viz/base.py | 208 +++++---- mlos_viz/mlos_viz/dabl.py | 62 ++- mlos_viz/mlos_viz/tests/test_mlos_viz.py | 4 +- mlos_viz/mlos_viz/util.py | 7 +- mlos_viz/mlos_viz/version.py | 2 +- mlos_viz/setup.py | 47 ++- 210 files changed, 6260 insertions(+), 4082 deletions(-) diff --git 
a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py index c32dea9bf6..2c6da8cc6a 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py @@ -20,7 +20,7 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: Flatten every dict in the hierarchy and rename the keys with the dict path. """ if isinstance(data, dict): - for (key, val) in data.items(): + for key, val in data.items(): yield from _flat_dict(val, f"{path}.{key}") else: yield (path, data) @@ -30,13 +30,15 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: """ Convert FIO read data from JSON to tall CSV. """ - with open(input_file, mode='r', encoding='utf-8') as fh_input: + with open(input_file, mode="r", encoding="utf-8") as fh_input: json_data = json.load(fh_input) - data = list(itertools.chain( - _flat_dict(json_data["jobs"][0], prefix), - _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util") - )) + data = list( + itertools.chain( + _flat_dict(json_data["jobs"][0], prefix), + _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"), + ) + ) tall_df = pandas.DataFrame(data, columns=["metric", "value"]) tall_df.to_csv(output_file, index=False) @@ -49,12 +51,12 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.") parser.add_argument( - "input", help="FIO benchmark results in JSON format (downloaded from a remote VM).") + "input", help="FIO benchmark results in JSON format (downloaded from a remote VM)." + ) parser.add_argument( - "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).") - parser.add_argument( - "--prefix", default="fio", - help="Prefix of the metric IDs (default 'fio')") + "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench)." + ) + parser.add_argument("--prefix", default="fio", help="Prefix of the metric IDs (default 'fio')") args = parser.parse_args() _main(args.input, args.output, args.prefix) diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py index 949b9f9d91..d41f20d2a9 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py @@ -14,17 +14,19 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ - open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in json.load(fh_tunables).items(): - line = f'{key} {val}' + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( + fname_output, "wt", encoding="utf-8", newline="" + ) as fh_config: + for key, val in json.load(fh_tunables).items(): + line = f"{key} {val}" fh_config.write(line + "\n") print(line) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate Redis config from tunable parameters JSON.") + description="generate Redis config from tunable parameters JSON." 
+ ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output Redis config file.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py index e33c717953..81a2b673a4 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py @@ -21,18 +21,19 @@ def _main(input_file: str, output_file: str) -> None: # Format the results from wide to long # The target is columns of metric and value to act as key-value pairs. df_long = ( - df_wide - .melt(id_vars=["test"]) + df_wide.melt(id_vars=["test"]) .assign(metric=lambda df: df["test"] + "_" + df["variable"]) .drop(columns=["test", "variable"]) .loc[:, ["metric", "value"]] ) # Add a default `score` metric to the end of the dataframe. - df_long = pd.concat([ - df_long, - pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}) - ]) + df_long = pd.concat( + [ + df_long, + pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}), + ] + ) df_long.to_csv(output_file, index=False) print(f"Converted: {input_file} -> {output_file}") @@ -42,7 +43,9 @@ def _main(input_file: str, output_file: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.") parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).") - parser.add_argument("output", help="Converted Redis benchmark data" + - " (to be consumed by OS Autotune framework).") + parser.add_argument( + "output", + help="Converted Redis benchmark data" + " (to be consumed by OS Autotune framework).", + ) args = parser.parse_args() _main(args.input, args.output) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py index 41bd162459..40a05e1511 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py @@ -14,8 +14,10 @@ JSON_CONFIG_FILE = "config-boot-time.json" NEW_CFG = "zz-mlos-boot-params.cfg" -with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \ - open(NEW_CFG, 'w', encoding='UTF-8') as fh_config: +with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open( + NEW_CFG, "w", encoding="UTF-8" +) as fh_config: for key, val in json.load(fh_json).items(): - fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$' - f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n') + fh_config.write( + 'GRUB_CMDLINE_LINUX_DEFAULT="$' f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n' + ) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py index de344d61fb..9f130e5c0e 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py @@ -14,9 +14,10 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", 
encoding="utf-8") as fh_tunables, \ - open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in json.load(fh_tunables).items(): + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( + fname_output, "wt", encoding="utf-8", newline="" + ) as fh_config: + for key, val in json.load(fh_tunables).items(): line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"' fh_config.write(line + "\n") print(line) @@ -24,7 +25,8 @@ def _main(fname_input: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Generate GRUB config from tunable parameters JSON.") + description="Generate GRUB config from tunable parameters JSON." + ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output shell script to configure GRUB.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py index 85a49a1817..e632495061 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py @@ -22,7 +22,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: tunables_meta = json.load(fh_meta) with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in tunables_data.items(): + for key, val in tunables_data.items(): meta = tunables_meta.get(key, {}) name_prefix = meta.get("name_prefix", "") line = f'echo "{val}" > {name_prefix}{key}' @@ -33,7 +33,8 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate a script to update kernel parameters from tunables JSON.") + description="generate a script to update kernel parameters from tunables JSON." + ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("meta", help="JSON file with tunable parameters metadata.") diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py index fa3b63e2e6..672a215aad 100644 --- a/mlos_bench/mlos_bench/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/config/schemas/__init__.py @@ -9,6 +9,6 @@ from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema __all__ = [ - 'ConfigSchema', - 'CONFIG_SCHEMA_DIR', + "ConfigSchema", + "CONFIG_SCHEMA_DIR", ] diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index 82cbcacce2..181f96e5d6 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -27,9 +27,14 @@ # It is used in `ConfigSchema.validate()` method below. # NOTE: this may cause pytest to fail if it's expecting exceptions # to be raised for invalid configs. 
-_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION' -_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower() - in {'true', 'y', 'yes', 'on', '1'}) +_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION" +_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in { + "true", + "y", + "yes", + "on", + "1", +} # Note: we separate out the SchemaStore from a class method on ConfigSchema @@ -80,10 +85,12 @@ def _load_registry(cls) -> None: """Also store them in a Registry object for referencing by recent versions of jsonschema.""" if not cls._SCHEMA_STORE: cls._load_schemas() - cls._REGISTRY = Registry().with_resources([ - (url, Resource.from_contents(schema, default_specification=DRAFT202012)) - for url, schema in cls._SCHEMA_STORE.items() - ]) + cls._REGISTRY = Registry().with_resources( + [ + (url, Resource.from_contents(schema, default_specification=DRAFT202012)) + for url, schema in cls._SCHEMA_STORE.items() + ] + ) @property def registry(self) -> Registry: diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index 4ccef7817b..3c14e63598 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -13,7 +13,7 @@ from mlos_bench.os_environ import environ -class DictTemplater: # pylint: disable=too-few-public-methods +class DictTemplater: # pylint: disable=too-few-public-methods """ Simple class to help with nested dictionary $var templating. """ @@ -32,9 +32,9 @@ def __init__(self, source_dict: Dict[str, Any]): # The source/target dictionary to expand. self._dict: Dict[str, Any] = {} - def expand_vars(self, *, - extra_source_dict: Optional[Dict[str, Any]] = None, - use_os_env: bool = False) -> Dict[str, Any]: + def expand_vars( + self, *, extra_source_dict: Optional[Dict[str, Any]] = None, use_os_env: bool = False + ) -> Dict[str, Any]: """ Expand the template variables in the destination dictionary. @@ -55,7 +55,9 @@ def expand_vars(self, *, assert isinstance(self._dict, dict) return self._dict - def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any: + def _expand_vars( + self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool + ) -> Any: """ Recursively expand $var strings in the currently operating dictionary. """ @@ -71,7 +73,7 @@ def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], elif isinstance(value, dict): # Note: we use a loop instead of dict comprehension in order to # allow secondary expansion of subsequent values immediately. 
- for (key, val) in value.items(): + for key, val in value.items(): value[key] = self._expand_vars(val, extra_source_dict, use_os_env) elif isinstance(value, list): value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value] diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py index a1ccadae5f..629e7d9c5f 100644 --- a/mlos_bench/mlos_bench/environments/__init__.py +++ b/mlos_bench/mlos_bench/environments/__init__.py @@ -15,12 +15,11 @@ from mlos_bench.environments.status import Status __all__ = [ - 'Status', - - 'Environment', - 'MockEnv', - 'RemoteEnv', - 'LocalEnv', - 'LocalFileShareEnv', - 'CompositeEnv', + "Status", + "Environment", + "MockEnv", + "RemoteEnv", + "LocalEnv", + "LocalFileShareEnv", + "CompositeEnv", ] diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index 61fbd69f50..f1ec25823c 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -48,15 +48,16 @@ class Environment(metaclass=abc.ABCMeta): """ @classmethod - def new(cls, - *, - env_name: str, - class_name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ) -> "Environment": + def new( + cls, + *, + env_name: str, + class_name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -94,16 +95,18 @@ def new(cls, config=config, global_config=global_config, tunables=tunables, - service=service + service=service, ) - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment with a given config. @@ -134,24 +137,29 @@ def __init__(self, self._const_args: Dict[str, TunableValue] = config.get("const_args", {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Environment: '%s' Service: %s", name, - self._service.pprint() if self._service else None) + _LOG.debug( + "Environment: '%s' Service: %s", + name, + self._service.pprint() if self._service else None, + ) if tunables is None: - _LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name) + _LOG.warning( + "No tunables provided for %s. 
Tunable inheritance across composite environments may be broken.", + name, + ) tunables = TunableGroups() groups = self._expand_groups( - config.get("tunable_params", []), - (global_config or {}).get("tunable_params_map", {})) + config.get("tunable_params", []), (global_config or {}).get("tunable_params_map", {}) + ) _LOG.debug("Tunable groups for: '%s' :: %s", name, groups) self._tunable_params = tunables.subgroup(groups) # If a parameter comes from the tunables, do not require it in the const_args or globals - req_args = ( - set(config.get("required_args", [])) - - set(self._tunable_params.get_param_values().keys()) + req_args = set(config.get("required_args", [])) - set( + self._tunable_params.get_param_values().keys() ) merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args) self._const_args = self._expand_vars(self._const_args, global_config or {}) @@ -160,8 +168,7 @@ def __init__(self, _LOG.debug("Parameters for '%s' :: %s", name, self._params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Config for: '%s'\n%s", - name, json.dumps(self.config, indent=2)) + _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2)) def _validate_json_config(self, config: dict, name: str) -> None: """ @@ -179,8 +186,9 @@ def _validate_json_config(self, config: dict, name: str) -> None: ConfigSchema.ENVIRONMENT.validate(json_config) @staticmethod - def _expand_groups(groups: Iterable[str], - groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]: + def _expand_groups( + groups: Iterable[str], groups_exp: Dict[str, Union[str, Sequence[str]]] + ) -> List[str]: """ Expand `$tunable_group` into actual names of the tunable groups. @@ -202,7 +210,9 @@ def _expand_groups(groups: Iterable[str], if grp[:1] == "$": tunable_group_name = grp[1:] if tunable_group_name not in groups_exp: - raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}") + raise KeyError( + f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}" + ) add_groups = groups_exp[tunable_group_name] res += [add_groups] if isinstance(add_groups, str) else add_groups else: @@ -210,7 +220,9 @@ def _expand_groups(groups: Iterable[str], return res @staticmethod - def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict: + def _expand_vars( + params: Dict[str, TunableValue], global_config: Dict[str, TunableValue] + ) -> dict: """ Expand `$var` into actual values of the variables. """ @@ -221,7 +233,7 @@ def _config_loader_service(self) -> "SupportsConfigLoading": assert self._service is not None return self._service.config_loader_service - def __enter__(self) -> 'Environment': + def __enter__(self) -> "Environment": """ Enter the environment's benchmarking context. """ @@ -232,9 +244,12 @@ def __enter__(self) -> 'Environment': self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the benchmarking environment. 
""" @@ -304,7 +319,8 @@ def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ return tunables.get_param_values( group_names=list(self._tunable_params.get_covariant_group_names()), - into_params=self._const_args.copy()) + into_params=self._const_args.copy(), + ) @property def tunable_params(self) -> TunableGroups: @@ -364,10 +380,15 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - # (Derived classes still have to check `self._tunable_params.is_updated()`). is_updated = self._tunable_params.is_updated() if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, { - name: self._tunable_params.is_updated([name]) - for name in self._tunable_params.get_covariant_group_names() - }) + _LOG.debug( + "Env '%s': Tunable groups reset = %s :: %s", + self, + is_updated, + { + name: self._tunable_params.is_updated([name]) + for name in self._tunable_params.get_covariant_group_names() + }, + ) else: _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated) diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index a71b8ab9be..36ab99a223 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -27,13 +27,15 @@ class CompositeEnv(Environment): Composite benchmark environment. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment with a given config. @@ -53,8 +55,13 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) # By default, the Environment includes only the tunables explicitly specified # in the "tunable_params" section of the config. `CompositeEnv`, however, must @@ -70,17 +77,19 @@ def __init__(self, # each CompositeEnv gets a copy of the original global config and adjusts it with # the `const_args` specific to it. 
global_config = (global_config or {}).copy() - for (key, val) in self._const_args.items(): + for key, val in self._const_args.items(): global_config.setdefault(key, val) for child_config_file in config.get("include_children", []): for env in self._config_loader_service.load_environment_list( - child_config_file, tunables, global_config, self._const_args, self._service): + child_config_file, tunables, global_config, self._const_args, self._service + ): self._add_child(env, tunables) for child_config in config.get("children", []): env = self._config_loader_service.build_environment( - child_config, tunables, global_config, self._const_args, self._service) + child_config, tunables, global_config, self._const_args, self._service + ) self._add_child(env, tunables) _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params) @@ -92,9 +101,12 @@ def __enter__(self) -> Environment: self._child_contexts = [env.__enter__() for env in self._children] return super().__enter__() - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: ex_throw = None for env in reversed(self._children): try: @@ -132,8 +144,11 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: pretty : str Pretty-printed environment configuration. """ - return super().pprint(indent, level) + '\n' + '\n'.join( - child.pprint(indent, level + 1) for child in self._children) + return ( + super().pprint(indent, level) + + "\n" + + "\n".join(child.pprint(indent, level + 1) for child in self._children) + ) def _add_child(self, env: Environment, tunables: TunableGroups) -> None: """ @@ -165,7 +180,8 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - """ assert self._in_context self._is_ready = super().setup(tunables, global_config) and all( - env_context.setup(tunables, global_config) for env_context in self._child_contexts) + env_context.setup(tunables, global_config) for env_context in self._child_contexts + ) return self._is_ready def teardown(self) -> None: diff --git a/mlos_bench/mlos_bench/environments/local/__init__.py b/mlos_bench/mlos_bench/environments/local/__init__.py index 0cdd8349b4..a99eefea19 100644 --- a/mlos_bench/mlos_bench/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/environments/local/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv __all__ = [ - 'LocalEnv', - 'LocalFileShareEnv', + "LocalEnv", + "LocalFileShareEnv", ] diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index da20f5c961..72616f7cd3 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -36,13 +36,15 @@ class LocalEnv(ScriptEnv): Scheduler-side Environment that runs scripts locally. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for local execution. 
@@ -65,11 +67,17 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ - "LocalEnv requires a service that supports local execution" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsLocalExec + ), "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service self._temp_dir: Optional[str] = None @@ -83,13 +91,18 @@ def __init__(self, def __enter__(self) -> Environment: assert self._temp_dir is None and self._temp_dir_context is None - self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir")) + self._temp_dir_context = self._local_exec_service.temp_dir_context( + self.config.get("temp_dir") + ) self._temp_dir = self._temp_dir_context.__enter__() return super().__enter__() - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the benchmarking environment. """ @@ -137,10 +150,14 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - fname = path_join(self._temp_dir, self._dump_meta_file) _LOG.debug("Dump tunables metadata to file: %s", fname) with open(fname, "wt", encoding="utf-8") as fh_meta: - json.dump({ - tunable.name: tunable.meta - for (tunable, _group) in self._tunable_params if tunable.meta - }, fh_meta) + json.dump( + { + tunable.name: tunable.meta + for (tunable, _group) in self._tunable_params + if tunable.meta + }, + fh_meta, + ) if self._script_setup: (return_code, _output) = self._local_exec(self._script_setup, self._temp_dir) @@ -180,18 +197,24 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: _LOG.debug("Not reading the data at: %s", self) return (Status.SUCCEEDED, timestamp, stdout_data) - data = self._normalize_columns(pandas.read_csv( - self._config_loader_service.resolve_path( - self._read_results_file, extra_paths=[self._temp_dir]), - index_col=False, - )) + data = self._normalize_columns( + pandas.read_csv( + self._config_loader_service.resolve_path( + self._read_results_file, extra_paths=[self._temp_dir] + ), + index_col=False, + ) + ) _LOG.debug("Read data:\n%s", data) if list(data.columns) == ["metric", "value"]: - _LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data)) + _LOG.info( + "Local results have (metric,value) header and %d rows: assume long format", + len(data), + ) data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list()) # Try to convert string metrics to numbers. 
- data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive) + data = data.apply(pandas.to_numeric, errors="coerce").fillna(data) # type: ignore[assignment] # (false positive) elif len(data) == 1: _LOG.info("Local results have 1 row: assume wide format") else: @@ -209,8 +232,8 @@ def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame: # Windows cmd interpretation of > redirect symbols can leave trailing spaces in # the final column, which leads to misnamed columns. # For now, we simply strip trailing spaces from column names to account for that. - if sys.platform == 'win32': - data.rename(str.rstrip, axis='columns', inplace=True) + if sys.platform == "win32": + data.rename(str.rstrip, axis="columns", inplace=True) return data def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: @@ -222,24 +245,23 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: assert self._temp_dir is not None try: fname = self._config_loader_service.resolve_path( - self._read_telemetry_file, extra_paths=[self._temp_dir]) + self._read_telemetry_file, extra_paths=[self._temp_dir] + ) # TODO: Use the timestamp of the CSV file as our status timestamp? # FIXME: We should not be assuming that the only output file type is a CSV. - data = self._normalize_columns( - pandas.read_csv(fname, index_col=False)) + data = self._normalize_columns(pandas.read_csv(fname, index_col=False)) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") expected_col_names = ["timestamp", "metric", "value"] if len(data.columns) != len(expected_col_names): - raise ValueError(f'Telemetry data must have columns {expected_col_names}') + raise ValueError(f"Telemetry data must have columns {expected_col_names}") if list(data.columns) != expected_col_names: # Assume no header - this is ok for telemetry data. 
- data = pandas.read_csv( - fname, index_col=False, names=expected_col_names) + data = pandas.read_csv(fname, index_col=False, names=expected_col_names) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") except FileNotFoundError as ex: @@ -248,10 +270,14 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: _LOG.debug("Read telemetry data:\n%s", data) col_dtypes: Mapping[int, Type] = {0: datetime} - return (status, timestamp, [ - (pandas.Timestamp(ts).to_pydatetime(), metric, value) - for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) - ]) + return ( + status, + timestamp, + [ + (pandas.Timestamp(ts).to_pydatetime(), metric, value) + for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) + ], + ) def teardown(self) -> None: """ @@ -283,7 +309,8 @@ def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple env_params = self._get_env_params() _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params) (return_code, stdout, stderr) = self._local_exec_service.local_exec( - script, env=env_params, cwd=cwd) + script, env=env_params, cwd=cwd + ) if return_code != 0: _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr) return (return_code, {"stdout": stdout, "stderr": stderr}) diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index 174afd387c..7a6862ab2c 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -29,13 +29,15 @@ class LocalFileShareEnv(LocalEnv): and uploads/downloads data to the shared file storage. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new application environment with a given config. @@ -59,14 +61,22 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) - assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ - "LocalEnv requires a service that supports local execution" + assert self._service is not None and isinstance( + self._service, SupportsLocalExec + ), "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service - assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \ - "LocalEnv requires a service that supports file upload/download operations" + assert self._service is not None and isinstance( + self._service, SupportsFileShareOps + ), "LocalEnv requires a service that supports file upload/download operations" self._file_share_service: SupportsFileShareOps = self._service self._upload = self._template_from_to("upload") @@ -77,14 +87,12 @@ def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]: Convert a list of {"from": "...", "to": "..."} to a list of pairs of string.Template objects so that we can plug in self._params into it later. """ - return [ - (Template(d['from']), Template(d['to'])) - for d in self.config.get(config_key, []) - ] + return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])] @staticmethod - def _expand(from_to: Iterable[Tuple[Template, Template]], - params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]: + def _expand( + from_to: Iterable[Tuple[Template, Template]], params: Mapping[str, TunableValue] + ) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. Return a generator of (str, str) pairs of paths. 
@@ -119,9 +127,14 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for (path_from, path_to) in self._expand(self._upload, params): - self._file_share_service.upload(self._params, self._config_loader_service.resolve_path( - path_from, extra_paths=[self._temp_dir]), path_to) + for path_from, path_to in self._expand(self._upload, params): + self._file_share_service.upload( + self._params, + self._config_loader_service.resolve_path( + path_from, extra_paths=[self._temp_dir] + ), + path_to, + ) return self._is_ready def _download_files(self, ignore_missing: bool = False) -> None: @@ -137,11 +150,15 @@ def _download_files(self, ignore_missing: bool = False) -> None: assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for (path_from, path_to) in self._expand(self._download, params): + for path_from, path_to in self._expand(self._download, params): try: - self._file_share_service.download(self._params, - path_from, self._config_loader_service.resolve_path( - path_to, extra_paths=[self._temp_dir])) + self._file_share_service.download( + self._params, + path_from, + self._config_loader_service.resolve_path( + path_to, extra_paths=[self._temp_dir] + ), + ) except FileNotFoundError as ex: _LOG.warning("Cannot download: %s", path_from) if not ignore_missing: diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index cc47b95500..2f4d4b0ab4 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -29,13 +29,15 @@ class MockEnv(Environment): _NOISE_VAR = 0.2 """Variance of the Gaussian noise added to the benchmark value.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment that produces mock benchmark data. @@ -55,8 +57,13 @@ def __init__(self, service: Service An optional service object. Not used by this class. """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) seed = int(self.config.get("mock_env_seed", -1)) self._random = random.Random(seed or None) if seed >= 0 else None self._range = self.config.get("mock_env_range") @@ -81,9 +88,9 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: return result # Simple convex function of all tunable parameters. - score = numpy.mean(numpy.square([ - self._normalized(tunable) for (tunable, _group) in self._tunable_params - ])) + score = numpy.mean( + numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params]) + ) # Add noise and shift the benchmark value from [0, 1] to a given range. 
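        # Aside, a worked example added for illustration: with one categorical tunable
        # at index 1 of 3 categories (normalized to 1 / 2 = 0.5) and one int tunable at
        # the top of its range (normalized to 1.0), the raw score above is
        # mean([0.5**2, 1.0**2]) = 0.625, before the noise and range shift below.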
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0 @@ -101,11 +108,11 @@ def _normalized(tunable: Tunable) -> float: """ val = None if tunable.is_categorical: - val = (tunable.categories.index(tunable.category) / - float(len(tunable.categories) - 1)) + val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1) elif tunable.is_numerical: - val = ((tunable.numerical_value - tunable.range[0]) / - float(tunable.range[1] - tunable.range[0])) + val = (tunable.numerical_value - tunable.range[0]) / float( + tunable.range[1] - tunable.range[0] + ) else: raise ValueError("Invalid parameter type: " + tunable.type) # Explicitly clip the value in case of numerical errors. diff --git a/mlos_bench/mlos_bench/environments/remote/__init__.py b/mlos_bench/mlos_bench/environments/remote/__init__.py index f07575ac86..be18bff2fe 100644 --- a/mlos_bench/mlos_bench/environments/remote/__init__.py +++ b/mlos_bench/mlos_bench/environments/remote/__init__.py @@ -14,10 +14,10 @@ from mlos_bench.environments.remote.vm_env import VMEnv __all__ = [ - 'HostEnv', - 'NetworkEnv', - 'OSEnv', - 'RemoteEnv', - 'SaaSEnv', - 'VMEnv', + "HostEnv", + "NetworkEnv", + "OSEnv", + "RemoteEnv", + "SaaSEnv", + "VMEnv", ] diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py index 05896c9e60..3b1abcd79a 100644 --- a/mlos_bench/mlos_bench/environments/remote/host_env.py +++ b/mlos_bench/mlos_bench/environments/remote/host_env.py @@ -22,13 +22,15 @@ class HostEnv(Environment): Remote host environment. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for host operations. @@ -49,10 +51,17 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM/host, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) - assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \ - "HostEnv requires a service that supports host provisioning operations" + assert self._service is not None and isinstance( + self._service, SupportsHostProvisioning + ), "HostEnv requires a service that supports host provisioning operations" self._host_service: SupportsHostProvisioning = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index 552f1729d9..afa38229f5 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -27,13 +27,15 @@ class NetworkEnv(Environment): but no real tuning is expected for it ... yet. 
""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for network operations. @@ -54,14 +56,21 @@ def __init__(self, An optional service object (e.g., providing methods to deploy a network, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) # Virtual networks can be used for more than one experiment, so by default # we don't attempt to deprovision them. self._deprovision_on_teardown = config.get("deprovision_on_teardown", False) - assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \ - "NetworkEnv requires a service that supports network provisioning" + assert self._service is not None and isinstance( + self._service, SupportsNetworkProvisioning + ), "NetworkEnv requires a service that supports network provisioning" self._network_service: SupportsNetworkProvisioning = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -105,7 +114,9 @@ def teardown(self) -> None: return # Else _LOG.info("Network tear down: %s", self) - (status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True) + (status, params) = self._network_service.deprovision_network( + self._params, ignore_errors=True + ) if status.is_pending(): (status, _) = self._network_service.wait_network_deployment(params, is_setup=False) diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py index ef733c77c2..9fa2b5886a 100644 --- a/mlos_bench/mlos_bench/environments/remote/os_env.py +++ b/mlos_bench/mlos_bench/environments/remote/os_env.py @@ -24,13 +24,15 @@ class OSEnv(Environment): OS Level Environment for a host. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for remote execution. @@ -53,14 +55,22 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance(self._service, SupportsOSOps), \ - "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance( + self._service, SupportsOSOps + ), "RemoteEnv requires a service that supports host operations" self._os_service: SupportsOSOps = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index cf38a57b01..683405c6c5 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -32,13 +32,15 @@ class RemoteEnv(ScriptEnv): e.g. Application Environment """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for remote execution. @@ -61,18 +63,25 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a Host, VM, OS, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) self._wait_boot = self.config.get("wait_boot", False) - assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \ - "RemoteEnv requires a service that supports remote execution operations" + assert self._service is not None and isinstance( + self._service, SupportsRemoteExec + ), "RemoteEnv requires a service that supports remote execution operations" self._remote_exec_service: SupportsRemoteExec = self._service if self._wait_boot: - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -170,7 +179,8 @@ def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optiona env_params = self._get_env_params() _LOG.debug("Submit script: %s with %s", self, env_params) (status, output) = self._remote_exec_service.remote_exec( - script, config=self._params, env_params=env_params) + script, config=self._params, env_params=env_params + ) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) if status in {Status.PENDING, Status.SUCCEEDED}: (status, output) = self._remote_exec_service.get_remote_exec_results(output) diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index b661bfad7e..211db536d0 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -23,13 +23,15 @@ class SaaSEnv(Environment): Cloud-based (configurable) SaaS environment. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for (configurable) cloud-based SaaS instance. @@ -50,15 +52,22 @@ def __init__(self, An optional service object (e.g., providing methods to configure the remote service). 
""" - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \ - "SaaSEnv requires a service that supports remote host configuration API" + assert self._service is not None and isinstance( + self._service, SupportsRemoteConfig + ), "SaaSEnv requires a service that supports remote host configuration API" self._config_service: SupportsRemoteConfig = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -84,7 +93,8 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False (status, _) = self._config_service.configure( - self._params, self._tunable_params.get_param_values()) + self._params, self._tunable_params.get_param_values() + ) if not status.is_succeeded(): return False @@ -93,7 +103,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False # Azure Flex DB instances currently require a VM reboot after reconfiguration. - if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'): + if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"): _LOG.info("Restarting: %s", self) (status, params) = self._host_service.restart_host(self._params) if status.is_pending(): diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index 129ac21a0f..d65d137459 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -27,13 +27,15 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta): _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]") - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for script execution. @@ -63,19 +65,29 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) self._script_setup = self.config.get("setup") self._script_run = self.config.get("run") self._script_teardown = self.config.get("teardown") self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", []) - self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {}) + self._shell_env_params_rename: Dict[str, str] = self.config.get( + "shell_env_params_rename", {} + ) results_stdout_pattern = self.config.get("results_stdout_pattern") - self._results_stdout_pattern: Optional[re.Pattern[str]] = \ - re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None + self._results_stdout_pattern: Optional[re.Pattern[str]] = ( + re.compile(results_stdout_pattern, flags=re.MULTILINE) + if results_stdout_pattern + else None + ) def _get_env_params(self, restrict: bool = True) -> Dict[str, str]: """ @@ -116,4 +128,6 @@ def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]: if not self._results_stdout_pattern: return {} _LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout) - return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)} + return { + key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout) + } diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index 4555ab7f50..f69893871f 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -20,7 +20,7 @@ else: from typing_extensions import TypeAlias -CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name +CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name if sys.version_info >= (3, 9): FutureReturnType: TypeAlias = Future[CoroReturnType] else: diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index c8e48dab69..d988e370b3 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -32,7 +32,7 @@ from mlos_bench.util import try_parse_val _LOG_LEVEL = logging.INFO -_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s' +_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s" logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT) _LOG = logging.getLogger(__name__) @@ -54,8 +54,7 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st For additional details, please see the website or the README.md files in the source tree: """ - parser = argparse.ArgumentParser(description=f"{description} : {long_text}", - epilog=epilog) + parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog) (args, args_rest) = self._parse_args(parser, argv) # Bootstrap config loader: command line takes priority. @@ -96,13 +95,13 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI. # It's useful to keep it there explicitly mostly for the --help output. 
if args.experiment_id: - self.global_config['experiment_id'] = args.experiment_id + self.global_config["experiment_id"] = args.experiment_id # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line if args.trial_config_repeat_count: self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count # Ensure that the trial_id is present since it gets used by some other # configs but is typically controlled by the run optimize loop. - self.global_config.setdefault('trial_id', 1) + self.global_config.setdefault("trial_id", 1) self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True) assert isinstance(self.global_config, dict) @@ -110,24 +109,29 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # --service cli args should override the config file values. service_files: List[str] = config.get("services", []) + (args.service or []) assert isinstance(self._parent_service, SupportsConfigLoading) - self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service) + self._parent_service = self._parent_service.load_services( + service_files, self.global_config, self._parent_service + ) env_path = args.environment or config.get("environment") if not env_path: _LOG.error("No environment config specified.") - parser.error("At least the Environment config must be specified." + - " Run `mlos_bench --help` and consult `README.md` for more info.") + parser.error( + "At least the Environment config must be specified." + + " Run `mlos_bench --help` and consult `README.md` for more info." + ) self.root_env_config = self._config_loader.resolve_path(env_path) self.environment: Environment = self._config_loader.load_environment( - self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service) + self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service + ) _LOG.info("Init environment: %s", self.environment) # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer self.tunables = self._init_tunable_values( args.random_init or config.get("random_init", False), config.get("random_seed") if args.random_seed is None else args.random_seed, - config.get("tunable_values", []) + (args.tunable_values or []) + config.get("tunable_values", []) + (args.tunable_values or []), ) _LOG.info("Init tunables: %s", self.tunables) @@ -137,7 +141,11 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st self.storage = self._load_storage(args.storage or config.get("storage")) _LOG.info("Init storage: %s", self.storage) - self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True)) + self.teardown: bool = ( + bool(args.teardown) + if args.teardown is not None + else bool(config.get("teardown", True)) + ) self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler")) _LOG.info("Init scheduler: %s", self.scheduler) @@ -156,87 +164,146 @@ def service(self) -> Service: return self._parent_service @staticmethod - def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]: + def _parse_args( + parser: argparse.ArgumentParser, argv: Optional[List[str]] + ) -> Tuple[argparse.Namespace, List[str]]: """ Parse the command line arguments. """ parser.add_argument( - '--config', required=False, - help='Main JSON5 configuration file. 
Its keys are the same as the' + - ' command line options and can be overridden by the latter.\n' + - '\n' + - ' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' + - ' for additional config examples for this and other arguments.') + "--config", + required=False, + help="Main JSON5 configuration file. Its keys are the same as the" + + " command line options and can be overridden by the latter.\n" + + "\n" + + " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ " + + " for additional config examples for this and other arguments.", + ) parser.add_argument( - '--log_file', '--log-file', required=False, - help='Path to the log file. Use stdout if omitted.') + "--log_file", + "--log-file", + required=False, + help="Path to the log file. Use stdout if omitted.", + ) parser.add_argument( - '--log_level', '--log-level', required=False, type=str, - help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' + - ' Set to DEBUG for debug, WARNING for warnings only.') + "--log_level", + "--log-level", + required=False, + type=str, + help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}." + + " Set to DEBUG for debug, WARNING for warnings only.", + ) parser.add_argument( - '--config_path', '--config-path', '--config-paths', '--config_paths', - nargs="+", action='extend', required=False, - help='One or more locations of JSON config files.') + "--config_path", + "--config-path", + "--config-paths", + "--config_paths", + nargs="+", + action="extend", + required=False, + help="One or more locations of JSON config files.", + ) parser.add_argument( - '--service', '--services', - nargs='+', action='extend', required=False, - help='Path to JSON file with the configuration of the service(s) for environment(s) to use.') + "--service", + "--services", + nargs="+", + action="extend", + required=False, + help="Path to JSON file with the configuration of the service(s) for environment(s) to use.", + ) parser.add_argument( - '--environment', required=False, - help='Path to JSON file with the configuration of the benchmarking environment(s).') + "--environment", + required=False, + help="Path to JSON file with the configuration of the benchmarking environment(s).", + ) parser.add_argument( - '--optimizer', required=False, - help='Path to the optimizer configuration file. If omitted, run' + - ' a single trial with default (or specified in --tunable_values).') + "--optimizer", + required=False, + help="Path to the optimizer configuration file. If omitted, run" + + " a single trial with default (or specified in --tunable_values).", + ) parser.add_argument( - '--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int, - help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.') + "--trial_config_repeat_count", + "--trial-config-repeat-count", + required=False, + type=int, + help="Number of times to repeat each config. Default is 1 trial per config, though more may be advised.", + ) parser.add_argument( - '--scheduler', required=False, - help='Path to the scheduler configuration file. By default, use' + - ' a single worker synchronous scheduler.') + "--scheduler", + required=False, + help="Path to the scheduler configuration file. By default, use" + + " a single worker synchronous scheduler.", + ) parser.add_argument( - '--storage', required=False, - help='Path to the storage configuration file.' 
+ - ' If omitted, use the ephemeral in-memory SQL storage.') + "--storage", + required=False, + help="Path to the storage configuration file." + + " If omitted, use the ephemeral in-memory SQL storage.", + ) parser.add_argument( - '--random_init', '--random-init', required=False, default=False, - dest='random_init', action='store_true', - help='Initialize tunables with random values. (Before applying --tunable_values).') + "--random_init", + "--random-init", + required=False, + default=False, + dest="random_init", + action="store_true", + help="Initialize tunables with random values. (Before applying --tunable_values).", + ) parser.add_argument( - '--random_seed', '--random-seed', required=False, type=int, - help='Seed to use with --random_init') + "--random_seed", + "--random-seed", + required=False, + type=int, + help="Seed to use with --random_init", + ) parser.add_argument( - '--tunable_values', '--tunable-values', nargs="+", action='extend', required=False, - help='Path to one or more JSON files that contain values of the tunable' + - ' parameters. This can be used for a single trial (when no --optimizer' + - ' is specified) or as default values for the first run in optimization.') + "--tunable_values", + "--tunable-values", + nargs="+", + action="extend", + required=False, + help="Path to one or more JSON files that contain values of the tunable" + + " parameters. This can be used for a single trial (when no --optimizer" + + " is specified) or as default values for the first run in optimization.", + ) parser.add_argument( - '--globals', nargs="+", action='extend', required=False, - help='Path to one or more JSON files that contain additional' + - ' [private] parameters of the benchmarking environment.') + "--globals", + nargs="+", + action="extend", + required=False, + help="Path to one or more JSON files that contain additional" + + " [private] parameters of the benchmarking environment.", + ) parser.add_argument( - '--no_teardown', '--no-teardown', required=False, default=None, - dest='teardown', action='store_false', - help='Disable teardown of the environment after the benchmark.') + "--no_teardown", + "--no-teardown", + required=False, + default=None, + dest="teardown", + action="store_false", + help="Disable teardown of the environment after the benchmark.", + ) parser.add_argument( - '--experiment_id', '--experiment-id', required=False, default=None, + "--experiment_id", + "--experiment-id", + required=False, + default=None, help=""" Experiment ID to use for the benchmark. If omitted, the value from the --cli config or --globals is used. @@ -246,7 +313,7 @@ def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> T changes are made to config files, scripts, versions, etc. This is left as a manual operation as detection of what is "incompatible" is not easily automatable across systems. 
- """ + """, ) # By default we use the command line arguments, but allow the caller to @@ -288,16 +355,18 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]: _LOG.debug("Parsed config: %s", config) return config - def _load_config(self, - args_globals: Iterable[str], - config_path: Iterable[str], - args_rest: Iterable[str], - global_config: Dict[str, Any]) -> Dict[str, Any]: + def _load_config( + self, + args_globals: Iterable[str], + config_path: Iterable[str], + args_rest: Iterable[str], + global_config: Dict[str, Any], + ) -> Dict[str, Any]: """ Get key/value pairs of the global configuration parameters from the specified config files (if any) and command line arguments. """ - for config_file in (args_globals or []): + for config_file in args_globals or []: conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS) assert isinstance(conf, dict) global_config.update(conf) @@ -306,8 +375,9 @@ def _load_config(self, global_config["config_path"] = config_path return global_config - def _init_tunable_values(self, random_init: bool, seed: Optional[int], - args_tunables: Optional[str]) -> TunableGroups: + def _init_tunable_values( + self, random_init: bool, seed: Optional[int], args_tunables: Optional[str] + ) -> TunableGroups: """ Initialize the tunables and load key/value pairs of the tunable values from given JSON files, if specified. @@ -317,8 +387,10 @@ def _init_tunable_values(self, random_init: bool, seed: Optional[int], if random_init: tunables = MockOptimizer( - tunables=tunables, service=None, - config={"start_with_defaults": False, "seed": seed}).suggest() + tunables=tunables, + service=None, + config={"start_with_defaults": False, "seed": seed}, + ).suggest() _LOG.debug("Init tunables: random = %s", tunables) if args_tunables is not None: @@ -339,15 +411,20 @@ def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer: if args_optimizer is None: # global_config may contain additional properties, so we need to # strip those out before instantiating the basic oneshot optimizer. 
- config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS} - return OneShotOptimizer( - self.tunables, config=config, service=self._parent_service) + config = { + key: val + for key, val in self.global_config.items() + if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS + } + return OneShotOptimizer(self.tunables, config=config, service=self._parent_service) class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER) assert isinstance(class_config, Dict) - optimizer = self._config_loader.build_optimizer(tunables=self.tunables, - service=self._parent_service, - config=class_config, - global_config=self.global_config) + optimizer = self._config_loader.build_optimizer( + tunables=self.tunables, + service=self._parent_service, + config=class_config, + global_config=self.global_config, + ) return optimizer def _load_storage(self, args_storage: Optional[str]) -> Storage: @@ -359,17 +436,20 @@ def _load_storage(self, args_storage: Optional[str]) -> Storage: if args_storage is None: # pylint: disable=import-outside-toplevel from mlos_bench.storage.sql.storage import SqlStorage - return SqlStorage(service=self._parent_service, - config={ - "drivername": "sqlite", - "database": ":memory:", - "lazy_schema_create": True, - }) + + return SqlStorage( + service=self._parent_service, + config={ + "drivername": "sqlite", + "database": ":memory:", + "lazy_schema_create": True, + }, + ) class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE) assert isinstance(class_config, Dict) - storage = self._config_loader.build_storage(service=self._parent_service, - config=class_config, - global_config=self.global_config) + storage = self._config_loader.build_storage( + service=self._parent_service, config=class_config, global_config=self.global_config + ) return storage def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: @@ -384,6 +464,7 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: if args_scheduler is None: # pylint: disable=import-outside-toplevel from mlos_bench.schedulers.sync_scheduler import SyncScheduler + return SyncScheduler( # All config values can be overridden from global config config={ diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py index f10fa3c82e..a61b55d440 100644 --- a/mlos_bench/mlos_bench/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/optimizers/__init__.py @@ -12,8 +12,8 @@ from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer __all__ = [ - 'Optimizer', - 'MockOptimizer', - 'OneShotOptimizer', - 'MlosCoreOptimizer', + "Optimizer", + "MockOptimizer", + "OneShotOptimizer", + "MlosCoreOptimizer", ] diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index b9df1db1b7..f719c236e5 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -26,7 +26,7 @@ _LOG = logging.getLogger(__name__) -class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes +class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ An abstract interface between the benchmarking framework and mlos_core optimizers. 
""" @@ -39,11 +39,13 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr "start_with_defaults", } - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): """ Create a new optimizer for the given configuration space defined by the tunables. @@ -67,19 +69,20 @@ def __init__(self, self._seed = int(config.get("seed", 42)) self._in_context = False - experiment_id = self._global_config.get('experiment_id') + experiment_id = self._global_config.get("experiment_id") self.experiment_id = str(experiment_id).strip() if experiment_id else None self._iter = 0 # If False, use the optimizer to suggest the initial configuration; # if True (default), use the already initialized values for the first iteration. self._start_with_defaults: bool = bool( - strtobool(str(self._config.pop('start_with_defaults', True)))) - self._max_iter = int(self._config.pop('max_suggestions', 100)) + strtobool(str(self._config.pop("start_with_defaults", True))) + ) + self._max_iter = int(self._config.pop("max_suggestions", 100)) - opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'}) + opt_targets: Dict[str, str] = self._config.pop("optimization_targets", {"score": "min"}) self._opt_targets: Dict[str, Literal[1, -1]] = {} - for (opt_target, opt_dir) in opt_targets.items(): + for opt_target, opt_dir in opt_targets.items(): if opt_dir == "min": self._opt_targets[opt_target] = 1 elif opt_dir == "max": @@ -107,7 +110,7 @@ def __repr__(self) -> str: ) return f"{self.name}({opt_targets},config={self._config})" - def __enter__(self) -> 'Optimizer': + def __enter__(self) -> "Optimizer": """ Enter the optimizer's context. """ @@ -116,9 +119,12 @@ def __enter__(self) -> 'Optimizer': self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the optimizer. """ @@ -203,7 +209,7 @@ def name(self) -> str: return self.__class__.__name__ @property - def targets(self) -> Dict[str, Literal['min', 'max']]: + def targets(self) -> Dict[str, Literal["min", "max"]]: """ A dictionary of {target: direction} of optimization targets. """ @@ -220,10 +226,12 @@ def supports_preload(self) -> bool: return True @abstractmethod - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: """ Pre-load the optimizer with the bulk data from previous experiments. @@ -241,8 +249,12 @@ def bulk_register(self, is_not_empty : bool True if there is data to register, false otherwise. 
""" - _LOG.info("Update the optimizer with: %d configs, %d scores, %d status values", - len(configs or []), len(scores or []), len(status or [])) + _LOG.info( + "Update the optimizer with: %d configs, %d scores, %d status values", + len(configs or []), + len(scores or []), + len(status or []), + ) if len(configs or []) != len(scores or []): raise ValueError("Numbers of configs and scores do not match.") if status is not None and len(configs or []) != len(status or []): @@ -271,8 +283,12 @@ def suggest(self) -> TunableGroups: return self._tunables.copy() @abstractmethod - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: """ Register the observation for the given configuration. @@ -293,15 +309,16 @@ def register(self, tunables: TunableGroups, status: Status, Benchmark scores extracted (and possibly transformed) from the dataframe that's being MINIMIZED. """ - _LOG.info("Iteration %d :: Register: %s = %s score: %s", - self._iter, tunables, status, score) + _LOG.info( + "Iteration %d :: Register: %s = %s score: %s", self._iter, tunables, status, score + ) if status.is_succeeded() == (score is None): # XOR raise ValueError("Status and score must be consistent.") return self._get_scores(status, score) - def _get_scores(self, status: Status, - scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] - ) -> Optional[Dict[str, float]]: + def _get_scores( + self, status: Status, scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] + ) -> Optional[Dict[str, float]]: """ Extract a scalar benchmark score from the dataframe. Change the sign if we are maximizing. @@ -330,7 +347,7 @@ def _get_scores(self, status: Status, assert scores is not None target_metrics: Dict[str, float] = {} - for (opt_target, opt_dir) in self._opt_targets.items(): + for opt_target, opt_dir in self._opt_targets.items(): val = scores[opt_target] assert val is not None target_metrics[opt_target] = float(val) * opt_dir @@ -345,7 +362,9 @@ def not_converged(self) -> bool: return self._iter < self._max_iter @abstractmethod - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: """ Get the best observation so far. diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index 62341c613d..a98edb463b 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -48,7 +48,8 @@ def _normalize_weights(weights: List[float]) -> List[float]: def _tunable_to_configspace( - tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace: + tunable: Tunable, group_name: Optional[str] = None, cost: int = 0 +) -> ConfigurationSpace: """ Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects, wrapped in a ConfigurationSpace for composability. 
@@ -71,27 +72,28 @@ def _tunable_to_configspace( meta = {"group": group_name, "cost": cost} # {"scaling": ""} if tunable.type == "categorical": - return ConfigurationSpace({ - tunable.name: CategoricalHyperparameter( - name=tunable.name, - choices=tunable.categories, - weights=_normalize_weights(tunable.weights) if tunable.weights else None, - default_value=tunable.default, - meta=meta) - }) + return ConfigurationSpace( + { + tunable.name: CategoricalHyperparameter( + name=tunable.name, + choices=tunable.categories, + weights=_normalize_weights(tunable.weights) if tunable.weights else None, + default_value=tunable.default, + meta=meta, + ) + } + ) distribution: Union[Uniform, Normal, Beta, None] = None if tunable.distribution == "uniform": distribution = Uniform() elif tunable.distribution == "normal": distribution = Normal( - mu=tunable.distribution_params["mu"], - sigma=tunable.distribution_params["sigma"] + mu=tunable.distribution_params["mu"], sigma=tunable.distribution_params["sigma"] ) elif tunable.distribution == "beta": distribution = Beta( - alpha=tunable.distribution_params["alpha"], - beta=tunable.distribution_params["beta"] + alpha=tunable.distribution_params["alpha"], beta=tunable.distribution_params["beta"] ) elif tunable.distribution is not None: raise TypeError(f"Invalid Distribution Type: {tunable.distribution}") @@ -103,22 +105,26 @@ def _tunable_to_configspace( log=bool(tunable.is_log), q=nullable(int, tunable.quantization), distribution=distribution, - default=(int(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None), - meta=meta + default=( + int(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None + ), + meta=meta, ) elif tunable.type == "float": range_hp = Float( name=tunable.name, bounds=tunable.range, log=bool(tunable.is_log), - q=tunable.quantization, # type: ignore[arg-type] + q=tunable.quantization, # type: ignore[arg-type] distribution=distribution, # type: ignore[arg-type] - default=(float(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None), - meta=meta + default=( + float(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None + ), + meta=meta, ) else: raise TypeError(f"Invalid Parameter Type: {tunable.type}") @@ -136,31 +142,37 @@ def _tunable_to_configspace( # Create three hyperparameters: one for regular values, # one for special values, and one to choose between the two. 
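    # Aside, added for illustration: the scheme below encodes a tunable that has
    # "special" values as three hyperparameters: the regular range parameter, a
    # categorical over the special values, and a categorical "type" switch that
    # selects SPECIAL vs. RANGE. The two EqualsCondition calls then ensure that
    # exactly one of the two value parameters is active in any sampled
    # configuration, so a suggestion never carries both a special and a range value.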
(special_name, type_name) = special_param_names(tunable.name) - conf_space = ConfigurationSpace({ - tunable.name: range_hp, - special_name: CategoricalHyperparameter( - name=special_name, - choices=tunable.special, - weights=special_weights, - default_value=tunable.default if tunable.default in tunable.special else None, - meta=meta - ), - type_name: CategoricalHyperparameter( - name=type_name, - choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], - weights=switch_weights, - default_value=TunableValueKind.SPECIAL, - ), - }) - conf_space.add_condition(EqualsCondition( - conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)) - conf_space.add_condition(EqualsCondition( - conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)) + conf_space = ConfigurationSpace( + { + tunable.name: range_hp, + special_name: CategoricalHyperparameter( + name=special_name, + choices=tunable.special, + weights=special_weights, + default_value=tunable.default if tunable.default in tunable.special else None, + meta=meta, + ), + type_name: CategoricalHyperparameter( + name=type_name, + choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], + weights=switch_weights, + default_value=TunableValueKind.SPECIAL, + ), + } + ) + conf_space.add_condition( + EqualsCondition(conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL) + ) + conf_space.add_condition( + EqualsCondition(conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE) + ) return conf_space -def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace: +def tunable_groups_to_configspace( + tunables: TunableGroups, seed: Optional[int] = None +) -> ConfigurationSpace: """ Convert TunableGroups to hyperparameters in ConfigurationSpace. @@ -178,11 +190,14 @@ def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = A new ConfigurationSpace instance that corresponds to the input TunableGroups. """ space = ConfigurationSpace(seed=seed) - for (tunable, group) in tunables: + for tunable, group in tunables: space.add_configuration_space( - prefix="", delimiter="", + prefix="", + delimiter="", configuration_space=_tunable_to_configspace( - tunable, group.name, group.get_current_cost())) + tunable, group.name, group.get_current_cost() + ), + ) return space @@ -201,7 +216,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration: A ConfigSpace Configuration. """ values: Dict[str, TunableValue] = {} - for (tunable, _group) in tunables: + for tunable, _group in tunables: if tunable.special: (special_name, type_name) = special_param_names(tunable.name) if tunable.value in tunable.special: @@ -222,10 +237,7 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]: In particular, remove and keys suffixes added by `special_param_names`. 
""" data = data.copy() - specials = [ - special_param_name_strip(k) - for k in data.keys() if special_param_name_is_temp(k) - ] + specials = [special_param_name_strip(k) for k in data.keys() if special_param_name_is_temp(k)] for k in specials: (special_name, type_name) = special_param_names(k) if data[type_name] == TunableValueKind.SPECIAL: diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 4f207f5fc9..6e5700a37d 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -28,11 +28,13 @@ class GridSearchOptimizer(TrackBestOptimizer): Grid search optimizer. """ - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) # Track the grid as a set of tuples of tunable values and reconstruct the @@ -51,11 +53,19 @@ def __init__(self, def _sanity_check(self) -> None: size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables]) if size == np.inf: - raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}") + raise ValueError( + f"Unquantized tunables are not supported for grid search: {self._tunables}" + ) if size > 10000: - _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables) + _LOG.warning( + "Large number %d of config points requested for grid search: %s", + size, + self._tunables, + ) if size > self._max_iter: - _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter) + _LOG.warning( + "Grid search size %d, is greater than max iterations %d", size, self._max_iter + ) def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]: """ @@ -68,12 +78,14 @@ def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], Non # names instead of the order given by TunableGroups. configs = [ configspace_data_to_tunable_values(dict(config)) - for config in - generate_grid(self.config_space, { - tunable.name: int(tunable.cardinality) - for (tunable, _group) in self._tunables - if tunable.quantization or tunable.type == "int" - }) + for config in generate_grid( + self.config_space, + { + tunable.name: int(tunable.cardinality) + for (tunable, _group) in self._tunables + if tunable.quantization or tunable.type == "int" + }, + ) ] names = set(tuple(configs.keys()) for configs in configs) assert len(names) == 1 @@ -103,15 +115,17 @@ def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]: # See NOTEs above. 
return (dict(zip(self._config_keys, config)) for config in self._suggested_configs) - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for (params, score, trial_status) in zip(configs, scores, status): + for params, score, trial_status in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -152,20 +166,32 @@ def suggest(self) -> TunableGroups: _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) try: - config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())) + config = dict( + ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()) + ) self._suggested_configs.remove(tuple(config.values())) except KeyError: - _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables) + _LOG.warning( + "Attempted to remove missing config (previously registered?) from suggested set: %s", + tunables, + ) return registered_score def not_converged(self) -> bool: if self._iter > self._max_iter: if bool(self._pending_configs): - _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s", - len(self._pending_configs), list(self._pending_configs.keys())) + _LOG.warning( + "Exceeded max iterations, but still have %d pending configs: %s", + len(self._pending_configs), + list(self._pending_configs.keys()), + ) return False return bool(self._pending_configs) diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index d7d50f1ca5..a13ebe8d10 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -40,35 +40,41 @@ class MlosCoreOptimizer(Optimizer): A wrapper class for the mlos_core optimizers. """ - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) - opt_type = getattr(OptimizerType, self._config.pop( - 'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name)) + opt_type = getattr( + OptimizerType, self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name) + ) if opt_type == OptimizerType.SMAC: - output_directory = self._config.get('output_directory') + output_directory = self._config.get("output_directory") if output_directory is not None: # If output_directory is specified, turn it into an absolute path. 
- self._config['output_directory'] = os.path.abspath(output_directory) + self._config["output_directory"] = os.path.abspath(output_directory) else: - _LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.") + _LOG.warning( + "SMAC optimizer output_directory was null. SMAC will use a temporary directory." + ) # Make sure max_trials >= max_iterations. - if 'max_trials' not in self._config: - self._config['max_trials'] = self._max_iter - assert int(self._config['max_trials']) >= self._max_iter, \ - f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" + if "max_trials" not in self._config: + self._config["max_trials"] = self._max_iter + assert ( + int(self._config["max_trials"]) >= self._max_iter + ), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" - if 'run_name' not in self._config and self.experiment_id: - self._config['run_name'] = self.experiment_id + if "run_name" not in self._config and self.experiment_id: + self._config["run_name"] = self.experiment_id - space_adapter_type = self._config.pop('space_adapter_type', None) - space_adapter_config = self._config.pop('space_adapter_config', {}) + space_adapter_type = self._config.pop("space_adapter_type", None) + space_adapter_config = self._config.pop("space_adapter_config", {}) if space_adapter_type is not None: space_adapter_type = getattr(SpaceAdapterType, space_adapter_type) @@ -82,9 +88,12 @@ def __init__(self, space_adapter_kwargs=space_adapter_config, ) - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: self._opt.cleanup() return super().__exit__(ex_type, ex_val, ex_tb) @@ -92,10 +101,12 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], def name(self) -> str: return f"{self.__class__.__name__}:{self._opt.__class__.__name__}" - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False @@ -103,7 +114,8 @@ def bulk_register(self, df_configs = self._to_df(configs) # Impute missing values, if necessary df_scores = self._adjust_signs_df( - pd.DataFrame([{} if score is None else score for score in scores])) + pd.DataFrame([{} if score is None else score for score in scores]) + ) opt_targets = list(self._opt_targets) if status is not None: @@ -130,7 +142,7 @@ def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame: """ In-place adjust the signs of the scores for MINIMIZATION problem. 
""" - for (opt_target, opt_dir) in self._opt_targets.items(): + for opt_target, opt_dir in self._opt_targets.items(): df_scores[opt_target] *= opt_dir return df_scores @@ -152,7 +164,7 @@ def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: df_configs = pd.DataFrame(configs) tunables_names = list(self._tunables.get_param_values().keys()) missing_cols = set(tunables_names).difference(df_configs.columns) - for (tunable, _group) in self._tunables: + for tunable, _group in self._tunables: if tunable.name in missing_cols: df_configs[tunable.name] = tunable.default else: @@ -184,22 +196,31 @@ def suggest(self) -> TunableGroups: df_config, _metadata = self._opt.suggest(defaults=self._start_with_defaults) self._start_with_defaults = False _LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config) - return tunables.assign( - configspace_data_to_tunable_values(df_config.loc[0].to_dict())) - - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: - registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION + return tunables.assign(configspace_data_to_tunable_values(df_config.loc[0].to_dict())) + + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: + registered_score = super().register( + tunables, status, score + ) # Sign-adjusted for MINIMIZATION if status.is_completed(): assert registered_score is not None df_config = self._to_df([tunables.get_param_values()]) _LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config) # TODO: Specify (in the config) which metrics to pass to the optimizer. # Issue: https://github.com/microsoft/MLOS/issues/745 - self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float)) + self._opt.register( + configs=df_config, scores=pd.DataFrame([registered_score], dtype=float) + ) return registered_score - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: (df_config, df_score, _df_context) = self._opt.get_best_observations() if len(df_config) == 0: return (None, None) diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index ada4411b58..8dd13eb182 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -24,11 +24,13 @@ class MockOptimizer(TrackBestOptimizer): Mock optimizer to test the Environment API. 
""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) rnd = random.Random(self.seed) self._random: Dict[str, Callable[[Tunable], TunableValue]] = { @@ -37,15 +39,17 @@ def __init__(self, "int": lambda tunable: rnd.randint(*tunable.range), } - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for (params, score, trial_status) in zip(configs, scores, status): + for params, score, trial_status in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -62,7 +66,7 @@ def suggest(self) -> TunableGroups: _LOG.info("Use default tunable values") self._start_with_defaults = False else: - for (tunable, _group) in tunables: + for tunable, _group in tunables: tunable.value = self._random[tunable.type](tunable) _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index 9ad1070c46..b7a14f8af2 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -24,11 +24,13 @@ class OneShotOptimizer(MockOptimizer): # TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344 - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) _LOG.info("Run a single iteration for: %s", self._tunables) self._max_iter = 1 # Always run for just one iteration. diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py index 32a23142e3..0fd54b2dfa 100644 --- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py @@ -24,17 +24,23 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta): Base Optimizer class that keeps track of the best score and configuration. 
""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) self._best_config: Optional[TunableGroups] = None self._best_score: Optional[Dict[str, float]] = None - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) if status.is_succeeded() and self._is_better(registered_score): self._best_score = registered_score @@ -48,7 +54,7 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: if self._best_score is None: return True assert registered_score is not None - for (opt_target, best_score) in self._best_score.items(): + for opt_target, best_score in self._best_score.items(): score = registered_score[opt_target] if score < best_score: return True @@ -56,7 +62,9 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: return False return False - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: if self._best_score is None: return (None, None) score = self._get_scores(Status.SUCCEEDED, self._best_score) diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index a7912688a1..7f26851c6b 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py @@ -22,16 +22,19 @@ from typing_extensions import TypeAlias if sys.version_info >= (3, 9): - EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object + EnvironType: TypeAlias = os._Environ[ + str + ] # pylint: disable=protected-access,disable=unsubscriptable-object else: - EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access + EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access # Handle case sensitivity differences between platforms. 
# https://stackoverflow.com/a/19023293 -if sys.platform == 'win32': +if sys.platform == "win32": import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8) + environ: EnvironType = nt.environ else: environ: EnvironType = os.environ -__all__ = ['environ'] +__all__ = ["environ"] diff --git a/mlos_bench/mlos_bench/run.py b/mlos_bench/mlos_bench/run.py index 85c8c2b0c5..57c48a87b9 100755 --- a/mlos_bench/mlos_bench/run.py +++ b/mlos_bench/mlos_bench/run.py @@ -20,8 +20,9 @@ _LOG = logging.getLogger(__name__) -def _main(argv: Optional[List[str]] = None - ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: +def _main( + argv: Optional[List[str]] = None, +) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv) diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py index c54e3c0efc..c53d11231d 100644 --- a/mlos_bench/mlos_bench/schedulers/__init__.py +++ b/mlos_bench/mlos_bench/schedulers/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.schedulers.sync_scheduler import SyncScheduler __all__ = [ - 'Scheduler', - 'SyncScheduler', + "Scheduler", + "SyncScheduler", ] diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index 0b6733e423..210e2784a5 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -31,13 +31,16 @@ class Scheduler(metaclass=ABCMeta): Base class for the optimization loop scheduling policies. """ - def __init__(self, *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: Storage, - root_env_config: str): + def __init__( + self, + *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: Storage, + root_env_config: str, + ): """ Create a new instance of the scheduler. The constructor of this and the derived classes is called by the persistence service @@ -60,8 +63,9 @@ def __init__(self, *, Path to the root environment configuration. """ self.global_config = global_config - config = merge_parameters(dest=config.copy(), source=global_config, - required_keys=["experiment_id", "trial_id"]) + config = merge_parameters( + dest=config.copy(), source=global_config, required_keys=["experiment_id", "trial_id"] + ) self._experiment_id = config["experiment_id"].strip() self._trial_id = int(config["trial_id"]) @@ -71,7 +75,9 @@ def __init__(self, *, self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1)) if self._trial_config_repeat_count <= 0: - raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}") + raise ValueError( + f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}" + ) self._do_teardown = bool(config.get("teardown", True)) @@ -95,7 +101,7 @@ def __repr__(self) -> str: """ return self.__class__.__name__ - def __enter__(self) -> 'Scheduler': + def __enter__(self) -> "Scheduler": """ Enter the scheduler's context. 
""" @@ -117,10 +123,12 @@ def __enter__(self) -> 'Scheduler': ).__enter__() return self - def __exit__(self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the context of the scheduler. """ @@ -142,8 +150,12 @@ def start(self) -> None: Start the optimization loop. """ assert self.experiment is not None - _LOG.info("START: Experiment: %s Env: %s Optimizer: %s", - self.experiment, self.environment, self.optimizer) + _LOG.info( + "START: Experiment: %s Env: %s Optimizer: %s", + self.experiment, + self.environment, + self.optimizer, + ) if _LOG.isEnabledFor(logging.INFO): _LOG.info("Root Environment:\n%s", self.environment.pprint()) @@ -204,27 +216,33 @@ def schedule_trial(self, tunables: TunableGroups) -> None: Add a configuration to the queue of trials. """ for repeat_i in range(1, self._trial_config_repeat_count + 1): - self._add_trial_to_queue(tunables, config={ - # Add some additional metadata to track for the trial such as the - # optimizer config used. - # Note: these values are unfortunately mutable at the moment. - # Consider them as hints of what the config was the trial *started*. - # It is possible that the experiment configs were changed - # between resuming the experiment (since that is not currently - # prevented). - "optimizer": self.optimizer.name, - "repeat_i": repeat_i, - "is_defaults": tunables.is_defaults, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(self.optimizer.targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - } - }) + self._add_trial_to_queue( + tunables, + config={ + # Add some additional metadata to track for the trial such as the + # optimizer config used. + # Note: these values are unfortunately mutable at the moment. + # Consider them as hints of what the config was the trial *started*. + # It is possible that the experiment configs were changed + # between resuming the experiment (since that is not currently + # prevented). + "optimizer": self.optimizer.name, + "repeat_i": repeat_i, + "is_defaults": tunables.is_defaults, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(self.optimizer.targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + }, + }, + ) - def _add_trial_to_queue(self, tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> None: + def _add_trial_to_queue( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: """ Add a configuration to the queue of trials. A wrapper for the `Experiment.new_trial` method. diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index a73a493533..3e196d4d4f 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -53,7 +53,9 @@ def run_trial(self, trial: Storage.Trial) -> None: trial.update(Status.FAILED, datetime.now(UTC)) return - (status, timestamp, results) = self.environment.run() # Block and wait for the final result. + (status, timestamp, results) = ( + self.environment.run() + ) # Block and wait for the final result. 
_LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results) # In async mode (TODO), poll the environment for status and telemetry diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py index bcc7d02d6f..dacbb88126 100644 --- a/mlos_bench/mlos_bench/services/__init__.py +++ b/mlos_bench/mlos_bench/services/__init__.py @@ -11,7 +11,7 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - 'Service', - 'FileShareService', - 'LocalExecService', + "Service", + "FileShareService", + "LocalExecService", ] diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index f00a7a1a00..63c222ee45 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -21,10 +21,13 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta): An abstract base of all file shares. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new file share with a given config. @@ -42,12 +45,16 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) @abstractmethod - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: """ Downloads contents from a remote share path to a local path. @@ -65,11 +72,18 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b if True (the default), download the entire directory tree. """ params = params or {} - _LOG.info("Download from File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", remote_path, local_path, params) + _LOG.info( + "Download from File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", + remote_path, + local_path, + params, + ) @abstractmethod - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: """ Uploads contents from a local path to remote share path. @@ -86,5 +100,10 @@ def upload(self, params: dict, local_path: str, remote_path: str, recursive: boo if True (the default), upload the entire directory tree. 
""" params = params or {} - _LOG.info("Upload to File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", local_path, remote_path, params) + _LOG.info( + "Upload to File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", + local_path, + remote_path, + params, + ) diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index e7c9365bf7..65725b6288 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -26,11 +26,13 @@ class Service: """ @classmethod - def new(cls, - class_name: str, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None) -> "Service": + def new( + cls, + class_name: str, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + ) -> "Service": """ Factory method for a new service with a given config. @@ -57,11 +59,13 @@ def new(cls, assert issubclass(cls, Service) return instantiate_from_config(cls, class_name, config, global_config, parent) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new service with a given config. @@ -101,8 +105,10 @@ def __init__(self, _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None) @staticmethod - def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None], - local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]: + def merge_methods( + ext_methods: Union[Dict[str, Callable], List[Callable], None], + local_methods: Union[Dict[str, Callable], List[Callable]], + ) -> Dict[str, Callable]: """ Merge methods from the external caller with the local ones. This function is usually called by the derived class constructor @@ -138,9 +144,12 @@ def __enter__(self) -> "Service": self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the Service mix-in context. @@ -177,9 +186,12 @@ def _enter_context(self) -> "Service": self._in_context = True return self - def _exit_context(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def _exit_context( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exits the context for this particular Service instance. @@ -265,10 +277,11 @@ def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None # Unfortunately, by creating a set, we may destroy the ability to # preserve the context enter/exit order, but hopefully it doesn't # matter. 
- svc_method.__self__ for _, svc_method in self._service_methods.items() + svc_method.__self__ + for _, svc_method in self._service_methods.items() # Note: some methods are actually stand alone functions, so we need # to filter them out. - if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service) + if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service) } def export(self) -> Dict[str, Callable]: diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index cac3216d61..55d8e67527 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -61,11 +61,13 @@ class ConfigPersistenceService(Service, SupportsConfigLoading): BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/") - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of config persistence service. @@ -82,17 +84,22 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.resolve_path, - self.load_config, - self.prepare_class_load, - self.build_service, - self.build_environment, - self.load_services, - self.load_environment, - self.load_environment_list, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.resolve_path, + self.load_config, + self.prepare_class_load, + self.build_service, + self.build_environment, + self.load_services, + self.load_environment, + self.load_environment_list, + ], + ), ) self._config_loader_service = self @@ -120,8 +127,7 @@ def config_paths(self) -> List[str]: """ return list(self._config_path) # make a copy to avoid modifications - def resolve_path(self, file_path: str, - extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -151,10 +157,11 @@ def resolve_path(self, file_path: str, _LOG.debug("Path not resolved: %s", file_path) return file_path - def load_config(self, - json_file_name: str, - schema_type: Optional[ConfigSchema], - ) -> Dict[str, Any]: + def load_config( + self, + json_file_name: str, + schema_type: Optional[ConfigSchema], + ) -> Dict[str, Any]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. 
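The set comprehension in Service.register() recovers the distinct owning Service instances from the registered bound methods via their __self__ attribute, filtering out plain functions that have no owner. A small sketch of that introspection with toy classes (not the real Service hierarchy):

    class Owner:  # toy stand-in for a Service that exports bound methods
        def __init__(self, name: str) -> None:
            self.name = name

        def do_work(self) -> str:
            return f"work from {self.name}"


    def standalone_helper() -> str:  # a plain function with no __self__
        return "no owner"


    a, b = Owner("a"), Owner("b")
    methods = {"do_a": a.do_work, "do_b": b.do_work, "helper": standalone_helper}

    owners = {
        method.__self__
        for method in methods.values()
        if hasattr(method, "__self__") and isinstance(method.__self__, Owner)
    }
    assert owners == {a, b}  # each owner appears once; the bare function is filtered out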
@@ -174,16 +181,22 @@ def load_config(self, """ json_file_name = self.resolve_path(json_file_name) _LOG.info("Load config: %s", json_file_name) - with open(json_file_name, mode='r', encoding='utf-8') as fh_json: + with open(json_file_name, mode="r", encoding="utf-8") as fh_json: config = json5.load(fh_json) if schema_type is not None: try: schema_type.validate(config) except (ValidationError, SchemaError) as ex: - _LOG.error("Failed to validate config %s against schema type %s at %s", - json_file_name, schema_type.name, schema_type.value) - raise ValueError(f"Failed to validate config {json_file_name} against " + - f"schema type {schema_type.name} at {schema_type.value}") from ex + _LOG.error( + "Failed to validate config %s against schema type %s at %s", + json_file_name, + schema_type.name, + schema_type.value, + ) + raise ValueError( + f"Failed to validate config {json_file_name} against " + + f"schema type {schema_type.name} at {schema_type.value}" + ) from ex if isinstance(config, dict) and config.get("$schema"): # Remove $schema attributes from the config after we've validated # them to avoid passing them on to other objects @@ -194,11 +207,14 @@ def load_config(self, del config["$schema"] else: _LOG.warning("Config %s is not validated against a schema.", json_file_name) - return config # type: ignore[no-any-return] + return config # type: ignore[no-any-return] - def prepare_class_load(self, config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]: + def prepare_class_load( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + ) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. Mix-in the global parameters and resolve the local file system paths, @@ -241,16 +257,22 @@ def prepare_class_load(self, config: Dict[str, Any], raise ValueError(f"Parameter {key} must be a string or a list") if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Instantiating: %s with config:\n%s", - class_name, json.dumps(class_config, indent=2)) + _LOG.debug( + "Instantiating: %s with config:\n%s", + class_name, + json.dumps(class_config, indent=2), + ) return (class_name, class_config) - def build_optimizer(self, *, - tunables: TunableGroups, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None) -> Optimizer: + def build_optimizer( + self, + *, + tunables: TunableGroups, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + ) -> Optimizer: """ Instantiation of mlos_bench Optimizer that depend on Service and TunableGroups. 
@@ -279,18 +301,24 @@ def build_optimizer(self, *, if tunables_path is not None: tunables = self._load_tunables(tunables_path, tunables) (class_name, class_config) = self.prepare_class_load(config, global_config) - inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract] - tunables=tunables, - config=class_config, - global_config=global_config, - service=service) + inst = instantiate_from_config( + Optimizer, + class_name, # type: ignore[type-abstract] + tunables=tunables, + config=class_config, + global_config=global_config, + service=service, + ) _LOG.info("Created: Optimizer %s", inst) return inst - def build_storage(self, *, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None) -> "Storage": + def build_storage( + self, + *, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + ) -> "Storage": """ Instantiation of mlos_bench Storage objects. @@ -312,20 +340,27 @@ def build_storage(self, *, from mlos_bench.storage.base_storage import ( Storage, # pylint: disable=import-outside-toplevel ) - inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - service=service) + + inst = instantiate_from_config( + Storage, + class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + service=service, + ) _LOG.info("Created: Storage %s", inst) return inst - def build_scheduler(self, *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: "Storage", - root_env_config: str) -> "Scheduler": + def build_scheduler( + self, + *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: "Storage", + root_env_config: str, + ) -> "Scheduler": """ Instantiation of mlos_bench Scheduler. @@ -353,22 +388,28 @@ def build_scheduler(self, *, from mlos_bench.schedulers.base_scheduler import ( Scheduler, # pylint: disable=import-outside-toplevel ) - inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - environment=environment, - optimizer=optimizer, - storage=storage, - root_env_config=root_env_config) + + inst = instantiate_from_config( + Scheduler, + class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + environment=environment, + optimizer=optimizer, + storage=storage, + root_env_config=root_env_config, + ) _LOG.info("Created: Scheduler %s", inst) return inst - def build_environment(self, # pylint: disable=too-many-arguments - config: Dict[str, Any], - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> Environment: + def build_environment( + self, # pylint: disable=too-many-arguments + config: Dict[str, Any], + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> Environment: """ Factory method for a new environment with a given config. 
@@ -408,16 +449,24 @@ def build_environment(self, # pylint: disable=too-many-arguments tunables = self._load_tunables(env_tunables_path, tunables) _LOG.debug("Creating env: %s :: %s", env_name, env_class) - env = Environment.new(env_name=env_name, class_name=env_class, - config=env_config, global_config=global_config, - tunables=tunables, service=service) + env = Environment.new( + env_name=env_name, + class_name=env_class, + config=env_config, + global_config=global_config, + tunables=tunables, + service=service, + ) _LOG.info("Created env: %s :: %s", env_name, env) return env - def _build_standalone_service(self, config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def _build_standalone_service( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -442,9 +491,12 @@ def _build_standalone_service(self, config: Dict[str, Any], _LOG.info("Created service: %s", service) return service - def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def _build_composite_service( + self, + config_list: Iterable[Dict[str, Any]], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -470,18 +522,21 @@ def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], service.register(parent.export()) for config in config_list: - service.register(self._build_standalone_service( - config, global_config, service).export()) + service.register( + self._build_standalone_service(config, global_config, service).export() + ) if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Created mix-in service: %s", service) return service - def build_service(self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def build_service( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -503,8 +558,7 @@ def build_service(self, services from the list plus the parent mix-in. """ if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Build service from config:\n%s", - json.dumps(config, indent=2)) + _LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2)) assert isinstance(config, dict) config_list: List[Dict[str, Any]] @@ -519,12 +573,14 @@ def build_service(self, return self._build_composite_service(config_list, global_config, parent) - def load_environment(self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> Environment: + def load_environment( + self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> Environment: """ Load and build new environment from the config file. 
@@ -551,12 +607,14 @@ def load_environment(self, # pylint: disable=too-many-arguments assert isinstance(config, dict) return self.build_environment(config, tunables, global_config, parent_args, service) - def load_environment_list(self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> List[Environment]: + def load_environment_list( + self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> List[Environment]: """ Load and build a list of environments from the config file. @@ -581,13 +639,14 @@ def load_environment_list(self, # pylint: disable=too-many-arguments A list of new benchmarking environments. """ config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT) - return [ - self.build_environment(config, tunables, global_config, parent_args, service) - ] + return [self.build_environment(config, tunables, global_config, parent_args, service)] - def load_services(self, json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def load_services( + self, + json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Read the configuration files and bundle all service methods from those configs into a single Service object. @@ -606,16 +665,16 @@ def load_services(self, json_file_names: Iterable[str], service : Service A collection of service methods. """ - _LOG.info("Load services: %s parent: %s", - json_file_names, parent.__class__.__name__) + _LOG.info("Load services: %s parent: %s", json_file_names, parent.__class__.__name__) service = Service({}, global_config, parent) for fname in json_file_names: config = self.load_config(fname, ConfigSchema.SERVICE) service.register(self.build_service(config, global_config, service).export()) return service - def _load_tunables(self, json_file_names: Iterable[str], - parent: TunableGroups) -> TunableGroups: + def _load_tunables( + self, json_file_names: Iterable[str], parent: TunableGroups + ) -> TunableGroups: """ Load a collection of tunable parameters from JSON files into the parent TunableGroup. diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py index abb87c8b52..b9d0c267c1 100644 --- a/mlos_bench/mlos_bench/services/local/__init__.py +++ b/mlos_bench/mlos_bench/services/local/__init__.py @@ -9,5 +9,5 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - 'LocalExecService', + "LocalExecService", ] diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 47534be7b1..0486ab7c80 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -79,11 +79,13 @@ class LocalExecService(TempDirContextService, SupportsLocalExec): due to reduced dependency management complications vs the target environment. 
""" - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of a service to run scripts locally. @@ -100,14 +102,16 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.local_exec]) + config, global_config, parent, self.merge_methods(methods, [self.local_exec]) ) self.abort_on_error = self.config.get("abort_on_error", True) - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. @@ -175,9 +179,9 @@ def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: subcmd_tokens.insert(0, sys.executable) return subcmd_tokens - def _local_exec_script(self, script_line: str, - env_params: Optional[Mapping[str, "TunableValue"]], - cwd: str) -> Tuple[int, str, str]: + def _local_exec_script( + self, script_line: str, env_params: Optional[Mapping[str, "TunableValue"]], cwd: str + ) -> Tuple[int, str, str]: """ Execute the script from `script_path` in a local process. @@ -206,7 +210,7 @@ def _local_exec_script(self, script_line: str, if env_params: env = {key: str(val) for (key, val) in env_params.items()} - if sys.platform == 'win32': + if sys.platform == "win32": # A hack to run Python on Windows with env variables set: env_copy = environ.copy() env_copy["PYTHONPATH"] = "" @@ -214,7 +218,7 @@ def _local_exec_script(self, script_line: str, env = env_copy try: - if sys.platform != 'win32': + if sys.platform != "win32": cmd = [" ".join(cmd)] _LOG.info("Run: %s", cmd) @@ -222,8 +226,15 @@ def _local_exec_script(self, script_line: str, _LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env)) _LOG.debug("Current working dir: %s", cwd) - proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True, - text=True, check=False, capture_output=True) + proc = subprocess.run( + cmd, + env=env or None, + cwd=cwd, + shell=True, + text=True, + check=False, + capture_output=True, + ) _LOG.debug("Run: return code = %d", proc.returncode) return (proc.returncode, proc.stdout, proc.stderr) diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py index a0cf3e0e57..8512b5d282 100644 --- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py +++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py @@ -28,11 +28,13 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta): This class is not supposed to be used as a standalone service. 
""" - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of a service that provides temporary directory context for local exec service. @@ -50,8 +52,7 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.temp_dir_context]) + config, global_config, parent, self.merge_methods(methods, [self.temp_dir_context]) ) self._temp_dir = self.config.get("temp_dir") if self._temp_dir: @@ -61,7 +62,9 @@ def __init__(self, self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir) _LOG.info("%s: temp dir: %s", self, self._temp_dir) - def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]: + def temp_dir_context( + self, path: Optional[str] = None + ) -> Union[TemporaryDirectory, nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py index 61a6c74942..12fe62eeb7 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py @@ -13,9 +13,9 @@ from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService __all__ = [ - 'AzureAuthService', - 'AzureFileShareService', - 'AzureNetworkService', - 'AzureSaaSConfigService', - 'AzureVMService', + "AzureAuthService", + "AzureFileShareService", + "AzureNetworkService", + "AzureSaaSConfigService", + "AzureVMService", ] diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 4121446caf..9074353221 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -27,13 +27,15 @@ class AzureAuthService(Service, SupportsAuth): Helper methods to get access to Azure services. """ - _REQ_INTERVAL = 300 # = 5 min - - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + _REQ_INTERVAL = 300 # = 5 min + + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure authentication services proxy. @@ -50,11 +52,16 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.get_access_token, - self.get_auth_headers, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.get_access_token, + self.get_auth_headers, + ], + ), ) # This parameter can come from command line as strings, so conversion is needed. 
@@ -70,12 +77,13 @@ def __init__(self, # Verify info required for SP auth early if "spClientId" in self.config: check_required_params( - self.config, { + self.config, + { "spClientId", "keyVaultName", "certName", "tenant", - } + }, ) def _init_sp(self) -> None: @@ -104,7 +112,9 @@ def _init_sp(self) -> None: cert_bytes = b64decode(secret.value) # Reauthenticate as the service principal. - self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes) + self._cred = azure_id.CertificateCredential( + tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes + ) def get_access_token(self) -> str: """ diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 9f2b504aff..3673baca76 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -29,9 +29,9 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): Helper methods to manage and deploy Azure resources via REST APIs. """ - _POLL_INTERVAL = 4 # seconds - _POLL_TIMEOUT = 300 # seconds - _REQUEST_TIMEOUT = 5 # seconds + _POLL_INTERVAL = 4 # seconds + _POLL_TIMEOUT = 300 # seconds + _REQUEST_TIMEOUT = 5 # seconds _REQUEST_TOTAL_RETRIES = 10 # Total number retries for each request _REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries})) @@ -39,19 +39,21 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): # https://docs.microsoft.com/en-us/rest/api/resources/deployments _URL_DEPLOY = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Resources" + - "/deployments/{deployment_name}" + - "?api-version=2022-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Resources" + + "/deployments/{deployment_name}" + + "?api-version=2022-05-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of an Azure Services proxy. @@ -69,32 +71,44 @@ def __init__(self, """ super().__init__(config, global_config, parent, methods) - check_required_params(self.config, [ - "subscription", - "resourceGroup", - ]) + check_required_params( + self.config, + [ + "subscription", + "resourceGroup", + ], + ) # These parameters can come from command line as strings, so conversion is needed. 
self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL)) self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT)) self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)) - self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)) + self._total_retries = int( + self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES) + ) + self._backoff_factor = float( + self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR) + ) self._deploy_template = {} self._deploy_params = {} if self.config.get("deploymentTemplatePath") is not None: # TODO: Provide external schema validation? template = self.config_loader_service.load_config( - self.config['deploymentTemplatePath'], schema_type=None) + self.config["deploymentTemplatePath"], schema_type=None + ) assert template is not None and isinstance(template, dict) self._deploy_template = template # Allow for recursive variable expansion as we do with global params and const_args. - deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config) + deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars( + extra_source_dict=global_config + ) self._deploy_params = merge_parameters(dest=deploy_params, source=global_config) else: - _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.") + _LOG.info( + "No deploymentTemplatePath provided. Deployment services will be unavailable." + ) @property def deploy_params(self) -> dict: @@ -129,7 +143,8 @@ def _get_session(self, params: dict) -> requests.Session: session = requests.Session() session.mount( "https://", - HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor))) + HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)), + ) session.headers.update(self._get_headers()) return session @@ -137,8 +152,9 @@ def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ - "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() @staticmethod @@ -234,9 +250,11 @@ def _check_operation_status(self, params: dict) -> Tuple[Status, dict]: return (Status.FAILED, {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) if response.status_code == 200: output = response.json() @@ -269,12 +287,16 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic Result is info on the operation runtime if SUCCEEDED, otherwise {}. 
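The retry wiring above leans on urllib3's Retry backoff; per the comment on _REQUEST_RETRY_BACKOFF_FACTOR, the delay between attempts grows as backoff_factor * 2**(previous retries), so 0.3 gives roughly 0.3 s, 0.6 s, 1.2 s, and so on between successive retries. A minimal sketch of building such a session against a placeholder URL (illustrative values only, not the service's actual config plumbing):

    import requests
    from requests.adapters import HTTPAdapter
    from urllib3.util.retry import Retry

    # Illustrative values; the service reads these from requestTotalRetries / requestBackoffFactor.
    total_retries = 10
    backoff_factor = 0.3

    session = requests.Session()
    session.mount(
        "https://",
        HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)),
    )

    # Placeholder URL: connection-level failures on this GET are retried with growing
    # delays instead of failing immediately.
    response = session.get("https://example.com/", timeout=5)
    print(response.status_code)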
""" params = self._set_default_params(params) - _LOG.info("Wait for %s to %s", params.get("deploymentName"), - "provision" if is_setup else "deprovision") + _LOG.info( + "Wait for %s to %s", + params.get("deploymentName"), + "provision" if is_setup else "deprovision", + ) return self._wait_while(self._check_deployment, Status.PENDING, params) - def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], - loop_status: Status, params: dict) -> Tuple[Status, dict]: + def _wait_while( + self, func: Callable[[dict], Tuple[Status, dict]], loop_status: Status, params: dict + ) -> Tuple[Status, dict]: """ Invoke `func` periodically while the status is equal to `loop_status`. Return TIMED_OUT when timing out. @@ -296,12 +318,18 @@ def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], """ params = self._set_default_params(params) config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"]) + dest=self.config.copy(), source=params, required_keys=["deploymentName"] + ) poll_period = params.get("pollInterval", self._poll_interval) - _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s", - config["deploymentName"], loop_status, poll_period, self._poll_timeout) + _LOG.debug( + "Wait for %s status %s :: poll %.2f timeout %d s", + config["deploymentName"], + loop_status, + poll_period, + self._poll_timeout, + ) ts_timeout = time.time() + self._poll_timeout poll_delay = poll_period @@ -325,7 +353,9 @@ def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], _LOG.warning("Request timed out: %s", params) return (Status.TIMED_OUT, {}) - def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements + def _check_deployment( + self, params: dict + ) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements """ Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. 
@@ -351,7 +381,7 @@ def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: di "subscription", "resourceGroup", "deploymentName", - ] + ], ) _LOG.info("Check deployment: %s", config["deploymentName"]) @@ -412,13 +442,18 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: if not self._deploy_template: raise ValueError(f"Missing deployment template: {self}") params = self._set_default_params(params) - config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"]) + config = merge_parameters( + dest=self.config.copy(), source=params, required_keys=["deploymentName"] + ) _LOG.info("Deploy: %s :: %s", config["deploymentName"], params) params = merge_parameters(dest=self._deploy_params.copy(), source=params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Deploy: %s merged params ::\n%s", - config["deploymentName"], json.dumps(params, indent=2)) + _LOG.debug( + "Deploy: %s merged params ::\n%s", + config["deploymentName"], + json.dumps(params, indent=2), + ) url = self._URL_DEPLOY.format( subscription=config["subscription"], @@ -431,22 +466,26 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: "mode": "Incremental", "template": self._deploy_template, "parameters": { - key: {"value": val} for (key, val) in params.items() + key: {"value": val} + for (key, val) in params.items() if key in self._deploy_template.get("parameters", {}) - } + }, } } if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2)) - response = requests.put(url, json=json_req, - headers=self._get_headers(), timeout=self._request_timeout) + response = requests.put( + url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout + ) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) else: _LOG.info("Response: %s", response) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 6ccd4ba09d..717086b52e 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -27,11 +27,13 @@ class AzureFileShareService(FileShareService): _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}" - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new file share Service for Azure environments with a given config. @@ -49,16 +51,19 @@ def __init__(self, New methods to register with the service. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) check_required_params( - self.config, { + self.config, + { "storageAccountName", "storageFileShareName", "storageAccountKey", - } + }, ) self._share_client = ShareClient.from_share_url( @@ -69,7 +74,9 @@ def __init__(self, credential=self.config["storageAccountKey"], ) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: super().download(params, remote_path, local_path, recursive) dir_client = self._share_client.get_directory_client(remote_path) if dir_client.exists(): @@ -94,7 +101,9 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b # Translate into non-Azure exception: raise FileNotFoundError(f"Cannot download: {remote_path}") from ex - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: super().upload(params, local_path, remote_path, recursive) self._upload(local_path, remote_path, recursive, set()) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index d65ee02cfd..4ba8bd3903 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -32,20 +32,22 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 _URL_DEPROVISION = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Network" + - "/virtualNetwork/{vnet_name}" + - "/delete" + - "?api-version=2023-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Network" + + "/virtualNetwork/{vnet_name}" + + "/delete" + + "?api-version=2023-05-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure Network services proxy. @@ -62,25 +64,34 @@ def __init__(self, New methods to register with the service. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - # SupportsNetworkProvisioning - self.provision_network, - self.deprovision_network, - self.wait_network_deployment, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + # SupportsNetworkProvisioning + self.provision_network, + self.deprovision_network, + self.wait_network_deployment, + ], + ), ) if not self._deploy_template: - raise ValueError("AzureNetworkService requires a deployment template:\n" - + f"config={config}\nglobal_config={global_config}") + raise ValueError( + "AzureNetworkService requires a deployment template:\n" + + f"config={config}\nglobal_config={global_config}" + ) - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vnetName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vnetName']}-deployment" - _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) + _LOG.info( + "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] + ) return params def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: @@ -151,15 +162,18 @@ def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple "resourceGroup", "deploymentName", "vnetName", - ] + ], ) _LOG.info("Deprovision Network: %s", config["vnetName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) - (status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vnet_name=config["vnetName"], - )) + (status, results) = self._azure_rest_api_post_helper( + config, + self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vnet_name=config["vnetName"], + ), + ) if ignore_errors and status == Status.FAILED: _LOG.warning("Ignoring error: %s", results) status = Status.SUCCEEDED diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index a92d279a6d..e7f626f505 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -32,20 +32,22 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig): # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations _URL_CONFIGURE = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/{provider}" + - "/{server_type}/{vm_name}" + - "/{update}" + - "?api-version={api_version}" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/{provider}" + + "/{server_type}/{vm_name}" + + "/{update}" + + "?api-version={api_version}" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + 
parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure services proxy. @@ -62,18 +64,20 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.configure, - self.is_config_pending - ]) + config, + global_config, + parent, + self.merge_methods(methods, [self.configure, self.is_config_pending]), ) - check_required_params(self.config, { - "subscription", - "resourceGroup", - "provider", - }) + check_required_params( + self.config, + { + "subscription", + "resourceGroup", + "provider", + }, + ) # Provide sane defaults for known DB providers. provider = self.config.get("provider") @@ -117,8 +121,7 @@ def __init__(self, # These parameters can come from command line as strings, so conversion is needed. self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - def configure(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service. @@ -156,33 +159,38 @@ def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]: If "isConfigPendingReboot" is set to True, rebooting a VM is necessary. Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED} """ - config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_get.format(vm_name=config["vmName"]) _LOG.debug("Request: GET %s", url) - response = requests.put( - url, headers=self._get_headers(), timeout=self._request_timeout) + response = requests.put(url, headers=self._get_headers(), timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) if response.status_code != 200: return (Status.FAILED, {}) # Currently, Azure Flex servers require a VM reboot. - return (Status.SUCCEEDED, {"isConfigPendingReboot": any( - {'False': False, 'True': True}[val['properties']['isConfigPendingRestart']] - for val in response.json()['value'] - )}) + return ( + Status.SUCCEEDED, + { + "isConfigPendingReboot": any( + {"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]] + for val in response.json()["value"] + ) + }, + ) def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ - "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() - def _config_one(self, config: Dict[str, Any], - param_name: str, param_value: Any) -> Tuple[Status, dict]: + def _config_one( + self, config: Dict[str, Any], param_name: str, param_value: Any + ) -> Tuple[Status, dict]: """ Update a single parameter of the Azure DB service. @@ -201,13 +209,15 @@ def _config_one(self, config: Dict[str, Any], A pair of Status and result. The result is always {}. 
Status is one of {PENDING, SUCCEEDED, FAILED} """ - config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name) _LOG.debug("Request: PUT %s", url) - response = requests.put(url, headers=self._get_headers(), - json={"properties": {"value": str(param_value)}}, - timeout=self._request_timeout) + response = requests.put( + url, + headers=self._get_headers(), + json={"properties": {"value": str(param_value)}}, + timeout=self._request_timeout, + ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) @@ -215,8 +225,7 @@ def _config_one(self, config: Dict[str, Any], return (Status.SUCCEEDED, {}) return (Status.FAILED, {}) - def _config_many(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service one-by-one. (If batch API is not available for it). @@ -234,14 +243,13 @@ def _config_many(self, config: Dict[str, Any], A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ - for (param_name, param_value) in params.items(): + for param_name, param_value in params.items(): (status, result) = self._config_one(config, param_name, param_value) if not status.is_succeeded(): return (status, result) return (Status.SUCCEEDED, {}) - def _config_batch(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: """ Batch update the parameters of an Azure DB service. @@ -258,19 +266,18 @@ def _config_batch(self, config: Dict[str, Any], A pair of Status and result. The result is always {}. 
Status is one of {PENDING, SUCCEEDED, FAILED} """ - config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_set.format(vm_name=config["vmName"]) json_req = { "value": [ - {"name": key, "properties": {"value": str(val)}} - for (key, val) in params.items() + {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items() ], # "resetAllToDefault": "True" } _LOG.debug("Request: POST %s", url) - response = requests.post(url, headers=self._get_headers(), - json=json_req, timeout=self._request_timeout) + response = requests.post( + url, headers=self._get_headers(), json=json_req, timeout=self._request_timeout + ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index ddce3cc935..effb0f9499 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -26,7 +26,13 @@ _LOG = logging.getLogger(__name__) -class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec): +class AzureVMService( + AzureDeploymentService, + SupportsHostProvisioning, + SupportsHostOps, + SupportsOSOps, + SupportsRemoteExec, +): """ Helper methods to manage VMs on Azure. """ @@ -38,35 +44,35 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start _URL_START = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/start" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/start" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off _URL_STOP = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/powerOff" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/powerOff" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate _URL_DEALLOCATE = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/deallocate" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/deallocate" + + "?api-version=2022-03-01" ) # TODO: This is probably the more correct URL to use for the deprovision operation. 
@@ -88,31 +94,33 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart _URL_REBOOT = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/restart" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/restart" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command _URL_REXEC_RUN = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/runCommand" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/runCommand" + + "?api-version=2022-03-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure VM services proxy. @@ -129,26 +137,31 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - # SupportsHostProvisioning - self.provision_host, - self.deprovision_host, - self.deallocate_host, - self.wait_host_deployment, - # SupportsHostOps - self.start_host, - self.stop_host, - self.restart_host, - self.wait_host_operation, - # SupportsOSOps - self.shutdown, - self.reboot, - self.wait_os_operation, - # SupportsRemoteExec - self.remote_exec, - self.get_remote_exec_results, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + # SupportsHostProvisioning + self.provision_host, + self.deprovision_host, + self.deallocate_host, + self.wait_host_deployment, + # SupportsHostOps + self.start_host, + self.stop_host, + self.restart_host, + self.wait_host_operation, + # SupportsOSOps + self.shutdown, + self.reboot, + self.wait_os_operation, + # SupportsRemoteExec + self.remote_exec, + self.get_remote_exec_results, + ], + ), ) # As a convenience, allow reading customData out of a file, rather than @@ -157,19 +170,23 @@ def __init__(self, # can be done using the `base64()` string function inside the ARM template. 
self._custom_data_file = self.config.get("customDataFile", None) if self._custom_data_file: - if self._deploy_params.get('customData', None): + if self._deploy_params.get("customData", None): raise ValueError("Both customDataFile and customData are specified.") - self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file) - with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh: + self._custom_data_file = self.config_loader_service.resolve_path( + self._custom_data_file + ) + with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh: self._deploy_params["customData"] = custom_data_fh.read() - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vmName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vmName']}-deployment" - _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) + _LOG.info( + "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] + ) return params def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: @@ -264,16 +281,19 @@ def deprovision_host(self, params: dict) -> Tuple[Status, dict]: "resourceGroup", "deploymentName", "vmName", - ] + ], ) _LOG.info("Deprovision VM: %s", config["vmName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) # TODO: Properly deprovision *all* resources specified in the ARM template. - return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def deallocate_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -301,14 +321,17 @@ def deallocate_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Deallocate VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_DEALLOCATE.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def start_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -333,14 +356,17 @@ def start_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Start VM: %s :: %s", config["vmName"], params) - return self._azure_rest_api_post_helper(config, self._URL_START.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_START.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: """ @@ -367,14 
+393,17 @@ def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Stop VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_STOP.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_STOP.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.stop_host(params, force) @@ -404,20 +433,24 @@ def restart_host(self, params: dict, force: bool = False) -> Tuple[Status, dict] "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Reboot VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_REBOOT.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.restart_host(params, force) - def remote_exec(self, script: Iterable[str], config: dict, - env_params: dict) -> Tuple[Status, dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple[Status, dict]: """ Run a command on Azure VM. @@ -447,7 +480,7 @@ def remote_exec(self, script: Iterable[str], config: dict, "subscription", "resourceGroup", "vmName", - ] + ], ) if _LOG.isEnabledFor(logging.INFO): @@ -456,7 +489,7 @@ def remote_exec(self, script: Iterable[str], config: dict, json_req = { "commandId": "RunShellScript", "script": list(script), - "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()] + "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()], } url = self._URL_REXEC_RUN.format( @@ -469,12 +502,15 @@ def remote_exec(self, script: Iterable[str], config: dict, _LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2)) response = requests.post( - url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout) + url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout + ) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) else: _LOG.info("Response: %s", response) @@ -482,10 +518,10 @@ def remote_exec(self, script: Iterable[str], config: dict, # TODO: extract the results from JSON response return (Status.SUCCEEDED, config) elif response.status_code == 202: - return (Status.PENDING, { - **config, - "asyncResultsUrl": response.headers.get("Azure-AsyncOperation") - }) + return ( + Status.PENDING, + {**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")}, + ) else: _LOG.error("Response: %s :: %s", response, response.text) # _LOG.error("Bad Request:\n%s", response.request.body) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index f623cdfcc8..f136747f7f 100644 --- 
a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -31,9 +31,14 @@ class CopyMode(Enum): class SshFileShareService(FileShareService, SshService): """A collection of functions for interacting with SSH servers as file shares.""" - async def _start_file_copy(self, params: dict, mode: CopyMode, - local_path: str, remote_path: str, - recursive: bool = True) -> None: + async def _start_file_copy( + self, + params: dict, + mode: CopyMode, + local_path: str, + remote_path: str, + recursive: bool = True, + ) -> None: # pylint: disable=too-many-arguments """ Starts a file copy operation @@ -73,40 +78,52 @@ async def _start_file_copy(self, params: dict, mode: CopyMode, raise ValueError(f"Unknown copy mode: {mode}") return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ] + ], ) super().download(params, remote_path, local_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive)) + self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive) + ) try: file_copy_future.result() except (OSError, SFTPError) as ex: - _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex) + _LOG.error( + "Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex + ) if isinstance(ex, SFTPNoSuchFile) or ( - isinstance(ex, SFTPFailure) and ex.code == 4 - and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory")) + isinstance(ex, SFTPFailure) + and ex.code == 4 + and any( + msg.lower() in ex.reason.lower() + for msg in ("File not found", "No such file or directory") + ) ): _LOG.warning("File %s does not exist on %s", remote_path, params) raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex raise ex - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ] + ], ) super().upload(params, local_path, remote_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)) + self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive) + ) try: file_copy_future.result() except (OSError, SFTPError) as ex: diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index a650ff0707..f04544eb05 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -29,11 +29,13 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], 
List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of an SSH Service. @@ -52,17 +54,25 @@ def __init__(self, # Same methods are also provided by the AzureVMService class # pylint: disable=duplicate-code super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.shutdown, - self.reboot, - self.wait_os_operation, - self.remote_exec, - self.get_remote_exec_results, - ])) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.shutdown, + self.reboot, + self.wait_os_operation, + self.remote_exec, + self.get_remote_exec_results, + ], + ), + ) self._shell = self.config.get("ssh_shell", "/bin/bash") - async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess: + async def _run_cmd( + self, params: dict, script: Iterable[str], env_params: dict + ) -> SSHCompletedProcess: """ Runs a command asynchronously on a host via SSH. @@ -85,16 +95,19 @@ async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values. # Handle transferring environment variables by making a script to set them. env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()] - script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()] + script_lines = env_script_lines + [ + line_split for line in script for line_split in line.splitlines() + ] # Note: connection.run() uses "exec" with a shell by default. - script_str = '\n'.join(script_lines) + script_str = "\n".join(script_lines) _LOG.debug("Running script on %s:\n%s", connection, script_str) - return await connection.run(script_str, - check=False, - timeout=self._request_timeout, - env=env_params) + return await connection.run( + script_str, check=False, timeout=self._request_timeout, env=env_params + ) - def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple["Status", dict]: """ Start running a command on remote host OS. 
@@ -121,9 +134,11 @@ def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> source=config, required_keys=[ "ssh_hostname", - ] + ], + ) + config["asyncRemoteExecResultsFuture"] = self._run_coroutine( + self._run_cmd(config, script, env_params) ) - config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params)) return (Status.PENDING, config) def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: @@ -154,7 +169,11 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr return ( - Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED, + ( + Status.SUCCEEDED + if result.exit_status == 0 and result.returncode == 0 + else Status.FAILED + ), { "stdout": stdout, "stderr": stderr, @@ -186,9 +205,9 @@ def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, d source=params, required_keys=[ "ssh_hostname", - ] + ], ) - cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list]) + cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list]) script = rf""" if [[ $EUID -ne 0 ]]; then sudo=$(command -v sudo) @@ -223,10 +242,10 @@ def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - 'shutdown -h now', - 'poweroff', - 'halt -p', - 'systemctl poweroff', + "shutdown -h now", + "poweroff", + "halt -p", + "systemctl poweroff", ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) @@ -248,11 +267,11 @@ def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - 'shutdown -r now', - 'reboot', - 'halt --reboot', - 'systemctl reboot', - 'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1', + "shutdown -r now", + "reboot", + "halt --reboot", + "systemctl reboot", + "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1", ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 8bc90eb3da..64bb7d9788 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -50,8 +50,8 @@ class SshClient(asyncssh.SSHClient): reconnect for each command. """ - _CONNECTION_PENDING = 'INIT' - _CONNECTION_LOST = 'LOST' + _CONNECTION_PENDING = "INIT" + _CONNECTION_LOST = "LOST" def __init__(self, *args: tuple, **kwargs: dict): self._connection_id: str = SshClient._CONNECTION_PENDING @@ -65,7 +65,7 @@ def __repr__(self) -> str: @staticmethod def id_from_connection(connection: SSHClientConnection) -> str: """Gets a unique id repr for the connection.""" - return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access + return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access @staticmethod def id_from_params(connect_params: dict) -> str: @@ -79,8 +79,9 @@ def connection_made(self, conn: SSHClientConnection) -> None: Changes the connection_id from _CONNECTION_PENDING to a unique id repr. 
""" self._conn_event.clear() - _LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \ - # pylint: disable=protected-access + _LOG.debug( + "%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn + ) # pylint: disable=protected-access self._connection_id = SshClient.id_from_connection(conn) self._connection = conn self._conn_event.set() @@ -90,9 +91,19 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self._conn_event.clear() _LOG.debug("%s: %s", current_thread().name, "connection_lost") if exc is None: - _LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc) + _LOG.debug( + "%s: gracefully disconnected ssh from %s: %s", + current_thread().name, + self._connection_id, + exc, + ) else: - _LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc) + _LOG.debug( + "%s: ssh connection lost on %s: %s", + current_thread().name, + self._connection_id, + exc, + ) self._connection_id = SshClient._CONNECTION_LOST self._connection = None self._conn_event.set() @@ -145,7 +156,9 @@ def exit(self) -> None: warn(RuntimeWarning("SshClientCache lock was still held on exit.")) self._cache_lock.release() - async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]: + async def get_client_connection( + self, connect_params: dict + ) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) client connection. @@ -168,13 +181,21 @@ async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientCo _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id) connection = await client.connection() if not connection: - _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id) + _LOG.debug( + "%s: Removing stale client connection %s from cache.", + current_thread().name, + connection_id, + ) self._cache.pop(connection_id) # Try to reconnect next. else: _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id) if connection_id not in self._cache: - _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id) + _LOG.debug( + "%s: Establishing client connection to %s", + current_thread().name, + connection_id, + ) connection, client = await asyncssh.create_connection(SshClient, **connect_params) assert isinstance(client, SshClient) self._cache[connection_id] = (connection, client) @@ -185,7 +206,7 @@ def cleanup(self) -> None: """ Closes all cached connections. """ - for (connection, _) in self._cache.values(): + for connection, _ in self._cache.values(): connection.close() self._cache = {} @@ -225,21 +246,23 @@ class SshService(Service, metaclass=ABCMeta): _REQUEST_TIMEOUT: Optional[float] = None # seconds - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__(config, global_config, parent, methods) # Make sure that the value we allow overriding on a per-connection # basis are present in the config so merge_parameters can do its thing. 
- self.config.setdefault('ssh_port', None) - assert isinstance(self.config['ssh_port'], (int, type(None))) - self.config.setdefault('ssh_username', None) - assert isinstance(self.config['ssh_username'], (str, type(None))) - self.config.setdefault('ssh_priv_key_path', None) - assert isinstance(self.config['ssh_priv_key_path'], (str, type(None))) + self.config.setdefault("ssh_port", None) + assert isinstance(self.config["ssh_port"], (int, type(None))) + self.config.setdefault("ssh_username", None) + assert isinstance(self.config["ssh_username"], (str, type(None))) + self.config.setdefault("ssh_priv_key_path", None) + assert isinstance(self.config["ssh_priv_key_path"], (str, type(None))) # None can be used to disable the request timeout. self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT) @@ -250,24 +273,24 @@ def __init__(self, # In general scripted commands shouldn't need a pty and having one # available can confuse some commands, though we may need to make # this configurable in the future. - 'request_pty': False, + "request_pty": False, # By default disable known_hosts checking (since most VMs expected to be dynamically created). - 'known_hosts': None, + "known_hosts": None, } - if 'ssh_known_hosts_file' in self.config: - self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None) - if isinstance(self._connect_params['known_hosts'], str): - known_hosts_file = os.path.expanduser(self._connect_params['known_hosts']) + if "ssh_known_hosts_file" in self.config: + self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None) + if isinstance(self._connect_params["known_hosts"], str): + known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"]) if not os.path.exists(known_hosts_file): raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist") - self._connect_params['known_hosts'] = known_hosts_file - if self._connect_params['known_hosts'] is None: + self._connect_params["known_hosts"] = known_hosts_file + if self._connect_params["known_hosts"] is None: _LOG.info("%s known_hosts checking is disabled per config.", self) - if 'ssh_keepalive_interval' in self.config: - keepalive_internal = self.config.get('ssh_keepalive_interval') - self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal) + if "ssh_keepalive_interval" in self.config: + keepalive_internal = self.config.get("ssh_keepalive_interval") + self._connect_params["keepalive_interval"] = nullable(int, keepalive_internal) def _enter_context(self) -> "SshService": # Start the background thread if it's not already running. @@ -277,9 +300,12 @@ def _enter_context(self) -> "SshService": super()._enter_context() return self - def _exit_context(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def _exit_context( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: # Stop the background thread if it's not needed anymore and potentially # cleanup the cache as well. assert self._in_context @@ -334,24 +360,26 @@ def _get_connect_params(self, params: dict) -> dict: # Start with the base config params. 
connect_params = self._connect_params.copy() - connect_params['host'] = params['ssh_hostname'] # required + connect_params["host"] = params["ssh_hostname"] # required - if params.get('ssh_port'): - connect_params['port'] = int(params.pop('ssh_port')) - elif self.config['ssh_port']: - connect_params['port'] = int(self.config['ssh_port']) + if params.get("ssh_port"): + connect_params["port"] = int(params.pop("ssh_port")) + elif self.config["ssh_port"]: + connect_params["port"] = int(self.config["ssh_port"]) - if 'ssh_username' in params: - connect_params['username'] = str(params.pop('ssh_username')) - elif self.config['ssh_username']: - connect_params['username'] = str(self.config['ssh_username']) + if "ssh_username" in params: + connect_params["username"] = str(params.pop("ssh_username")) + elif self.config["ssh_username"]: + connect_params["username"] = str(self.config["ssh_username"]) - priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path']) + priv_key_file: Optional[str] = params.get( + "ssh_priv_key_path", self.config["ssh_priv_key_path"] + ) if priv_key_file: priv_key_file = os.path.expanduser(priv_key_file) if not os.path.exists(priv_key_file): raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist") - connect_params['client_keys'] = [priv_key_file] + connect_params["client_keys"] = [priv_key_file] return connect_params @@ -370,4 +398,6 @@ async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnectio The connection and client objects. """ assert self._in_context - return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params)) + return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection( + self._get_connect_params(params) + ) diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index 725d0c3306..02bb06e755 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -18,12 +18,12 @@ from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec __all__ = [ - 'SupportsAuth', - 'SupportsConfigLoading', - 'SupportsFileShareOps', - 'SupportsHostProvisioning', - 'SupportsLocalExec', - 'SupportsNetworkProvisioning', - 'SupportsRemoteConfig', - 'SupportsRemoteExec', + "SupportsAuth", + "SupportsConfigLoading", + "SupportsFileShareOps", + "SupportsHostProvisioning", + "SupportsLocalExec", + "SupportsNetworkProvisioning", + "SupportsRemoteConfig", + "SupportsRemoteExec", ] diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index 05853da0a9..b09788476f 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -34,8 +34,7 @@ class SupportsConfigLoading(Protocol): Protocol interface for helper functions to lookup and load configs. """ - def resolve_path(self, file_path: str, - extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -53,7 +52,9 @@ def resolve_path(self, file_path: str, An actual path to the config or script. 
""" - def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]: + def load_config( + self, json_file_name: str, schema_type: Optional[ConfigSchema] + ) -> Union[dict, List[dict]]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. @@ -72,12 +73,14 @@ def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) Free-format dictionary that contains the configuration. """ - def build_environment(self, # pylint: disable=too-many-arguments - config: dict, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None) -> "Environment": + def build_environment( + self, # pylint: disable=too-many-arguments + config: dict, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -107,12 +110,13 @@ def build_environment(self, # pylint: disable=too-many-arguments """ def load_environment_list( # pylint: disable=too-many-arguments - self, - json_file_name: str, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None) -> List["Environment"]: + self, + json_file_name: str, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None, + ) -> List["Environment"]: """ Load and build a list of environments from the config file. @@ -137,9 +141,12 @@ def load_environment_list( # pylint: disable=too-many-arguments A list of new benchmarking environments. """ - def load_services(self, json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None) -> "Service": + def load_services( + self, + json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + ) -> "Service": """ Read the configuration files and bundle all service methods from those configs into a single Service object. diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py index 87ec9e49da..8252dc17ed 100644 --- a/mlos_bench/mlos_bench/services/types/fileshare_type.py +++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py @@ -15,7 +15,9 @@ class SupportsFileShareOps(Protocol): Protocol interface for file share operations. """ - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: """ Downloads contents from a remote share path to a local path. @@ -33,7 +35,9 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b if True (the default), download the entire directory tree. """ - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: """ Uploads contents from a local path to remote share path. 
diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index c4c5f01ddc..126966c713 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -32,9 +32,12 @@ class SupportsLocalExec(Protocol): Used in LocalEnv and provided by LocalExecService. """ - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, TunableValue]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, TunableValue]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. @@ -55,7 +58,9 @@ def local_exec(self, script_lines: Iterable[str], A 3-tuple of return code, stdout, and stderr of the script process. """ - def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: + def temp_dir_context( + self, path: Optional[str] = None + ) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index fb753aa21c..50b24cc4b8 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -56,7 +56,9 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ - def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]: + def deprovision_network( + self, params: dict, ignore_errors: bool = True + ) -> Tuple["Status", dict]: """ Deprovisions the Network by deleting it. diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index c653e10c2b..f93de1eab1 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -18,8 +18,7 @@ class SupportsRemoteConfig(Protocol): Protocol interface for configuring cloud services. """ - def configure(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple["Status", dict]: + def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple["Status", dict]: """ Update the parameters of a SaaS service in the cloud. diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index 096cb3c675..f6ca57912a 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -20,8 +20,9 @@ class SupportsRemoteExec(Protocol): scripts on a remote host OS. """ - def remote_exec(self, script: Iterable[str], config: dict, - env_params: dict) -> Tuple["Status", dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple["Status", dict]: """ Run a command on remote host OS. 
diff --git a/mlos_bench/mlos_bench/storage/__init__.py b/mlos_bench/mlos_bench/storage/__init__.py index 9ae5c80f36..0812270747 100644 --- a/mlos_bench/mlos_bench/storage/__init__.py +++ b/mlos_bench/mlos_bench/storage/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.storage.storage_factory import from_config __all__ = [ - 'Storage', - 'from_config', + "Storage", + "from_config", ] diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index ce07e44e2b..47581f0725 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -32,12 +32,15 @@ class ExperimentData(metaclass=ABCMeta): RESULT_COLUMN_PREFIX = "result." CONFIG_COLUMN_PREFIX = "config." - def __init__(self, *, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str): + def __init__( + self, + *, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str, + ): self._experiment_id = experiment_id self._description = description self._root_env_config = root_env_config @@ -142,9 +145,9 @@ def default_tunable_config_id(self) -> Optional[int]: trials_items = sorted(self.trials.items()) if not trials_items: return None - for (_trial_id, trial) in trials_items: + for _trial_id, trial in trials_items: # Take the first config id marked as "defaults" when it was instantiated. - if strtobool(str(trial.metadata_dict.get('is_defaults', False))): + if strtobool(str(trial.metadata_dict.get("is_defaults", False))): return trial.tunable_config_id # Fallback (min trial_id) return trials_items[0][1].tunable_config_id diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 2165fa706f..b7df86a4b7 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -30,10 +30,12 @@ class Storage(metaclass=ABCMeta): and storage systems (e.g., SQLite or MLFLow). """ - def __init__(self, - config: Dict[str, Any], - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + config: Dict[str, Any], + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): """ Create a new storage object. @@ -74,13 +76,16 @@ def experiments(self) -> Dict[str, ExperimentData]: """ @abstractmethod - def experiment(self, *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment': + def experiment( + self, + *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal["min", "max"]], + ) -> "Storage.Experiment": """ Create a new experiment in the storage. @@ -116,23 +121,27 @@ class Experiment(metaclass=ABCMeta): This class is instantiated in the `Storage.experiment()` method. 
""" - def __init__(self, - *, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal['min', 'max']]): + def __init__( + self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal["min", "max"]], + ): self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config) + (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( + root_env_config + ) self._description = description self._opt_targets = opt_targets self._in_context = False - def __enter__(self) -> 'Storage.Experiment': + def __enter__(self) -> "Storage.Experiment": """ Enter the context of the experiment. @@ -144,9 +153,12 @@ def __enter__(self) -> 'Storage.Experiment': self._in_context = True return self - def __exit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: """ End the context of the experiment. @@ -157,8 +169,9 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], _LOG.debug("Finishing experiment: %s", self) else: assert exc_type and exc_val - _LOG.warning("Finishing experiment: %s", self, - exc_info=(exc_type, exc_val, exc_tb)) + _LOG.warning( + "Finishing experiment: %s", self, exc_info=(exc_type, exc_val, exc_tb) + ) assert self._in_context self._teardown(is_ok) self._in_context = False @@ -248,8 +261,10 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: """ @abstractmethod - def load(self, last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load( + self, + last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: """ Load (tunable values, benchmark scores, status) to warm-up the optimizer. @@ -269,7 +284,9 @@ def load(self, last_trial_id: int = -1, """ @abstractmethod - def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']: + def pending_trials( + self, timestamp: datetime, *, running: bool + ) -> Iterator["Storage.Trial"]: """ Return an iterator over the pending trials that are scheduled to run on or before the specified timestamp. @@ -289,8 +306,12 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Sto """ @abstractmethod - def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial': + def new_trial( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> "Storage.Trial": """ Create a new experiment run in the storage. @@ -317,10 +338,16 @@ class Trial(metaclass=ABCMeta): This class is instantiated in the `Storage.Experiment.trial()` method. 
""" - def __init__(self, *, - tunables: TunableGroups, experiment_id: str, trial_id: int, - tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']], - config: Optional[Dict[str, Any]] = None): + def __init__( + self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + opt_targets: Dict[str, Literal["min", "max"]], + config: Optional[Dict[str, Any]] = None, + ): self._tunables = tunables self._experiment_id = experiment_id self._trial_id = trial_id @@ -378,9 +405,9 @@ def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, An return config @abstractmethod - def update(self, status: Status, timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update( + self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: """ Update the storage with the results of the experiment. @@ -404,14 +431,18 @@ def update(self, status: Status, timestamp: datetime, assert metrics is not None opt_targets = set(self._opt_targets.keys()) if not opt_targets.issubset(metrics.keys()): - _LOG.warning("Trial %s :: opt.targets missing: %s", - self, opt_targets.difference(metrics.keys())) + _LOG.warning( + "Trial %s :: opt.targets missing: %s", + self, + opt_targets.difference(metrics.keys()), + ) # raise ValueError() return metrics @abstractmethod - def update_telemetry(self, status: Status, timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]]) -> None: + def update_telemetry( + self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] + ) -> None: """ Save the experiment's telemetry data and intermediate status. diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index b3b2bed86a..6ad397d753 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -31,13 +31,16 @@ class TrialData(metaclass=ABCMeta): of tunable parameters). """ - def __init__(self, *, - experiment_id: str, - trial_id: int, - tunable_config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status): + def __init__( + self, + *, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status, + ): self._experiment_id = experiment_id self._trial_id = trial_id self._tunable_config_id = tunable_config_id diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py index 0dce110b1b..0c9adce22d 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py @@ -21,8 +21,7 @@ class TunableConfigData(metaclass=ABCMeta): A configuration in this context is the set of tunable parameter values. 
""" - def __init__(self, *, - tunable_config_id: int): + def __init__(self, *, tunable_config_id: int): self._tunable_config_id = tunable_config_id def __repr__(self) -> str: diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py index 18c50035a9..6ad0fe185a 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py @@ -27,10 +27,13 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta): (e.g., for repeats), which we call a (tunable) config trial group. """ - def __init__(self, *, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None): + def __init__( + self, + *, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None, + ): self._experiment_id = experiment_id self._tunable_config_id = tunable_config_id # can be lazily initialized as necessary: @@ -77,7 +80,10 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id + return ( + self._tunable_config_id == other._tunable_config_id + and self._experiment_id == other._experiment_id + ) @property @abstractmethod diff --git a/mlos_bench/mlos_bench/storage/sql/__init__.py b/mlos_bench/mlos_bench/storage/sql/__init__.py index 735e21bcaf..cf09b9aa5a 100644 --- a/mlos_bench/mlos_bench/storage/sql/__init__.py +++ b/mlos_bench/mlos_bench/storage/sql/__init__.py @@ -8,5 +8,5 @@ from mlos_bench.storage.sql.storage import SqlStorage __all__ = [ - 'SqlStorage', + "SqlStorage", ] diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index c7ee73a3bc..bdeb6d8bf3 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -18,10 +18,8 @@ def get_trials( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]: + engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None +) -> Dict[int, TrialData]: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. @@ -30,13 +28,18 @@ def get_trials( from mlos_bench.storage.sql.trial_data import ( TrialSqlData, # pylint: disable=import-outside-toplevel,cyclic-import ) + with engine.connect() as conn: # Build up sql a statement for fetching trials. - stmt = schema.trial.select().where( - schema.trial.c.exp_id == experiment_id, - ).order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), + stmt = ( + schema.trial.select() + .where( + schema.trial.c.exp_id == experiment_id, + ) + .order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -60,10 +63,8 @@ def get_trials( def get_results_df( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None) -> pandas.DataFrame: + engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None +) -> pandas.DataFrame: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. 
@@ -72,15 +73,22 @@ def get_results_df( # pylint: disable=too-many-locals with engine.connect() as conn: # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config. - tunable_config_group_id_stmt = schema.trial.select().with_only_columns( - schema.trial.c.exp_id, - schema.trial.c.config_id, - func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'), - ).where( - schema.trial.c.exp_id == experiment_id, - ).group_by( - schema.trial.c.exp_id, - schema.trial.c.config_id, + tunable_config_group_id_stmt = ( + schema.trial.select() + .with_only_columns( + schema.trial.c.exp_id, + schema.trial.c.config_id, + func.min(schema.trial.c.trial_id) + .cast(Integer) + .label("tunable_config_trial_group_id"), + ) + .where( + schema.trial.c.exp_id == experiment_id, + ) + .group_by( + schema.trial.c.exp_id, + schema.trial.c.config_id, + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -90,18 +98,22 @@ def get_results_df( tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery() # Get each trial's metadata. - cur_trials_stmt = select( - schema.trial, - tunable_config_trial_group_id_subquery, - ).where( - schema.trial.c.exp_id == experiment_id, - and_( - tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, - tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, - ), - ).order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), + cur_trials_stmt = ( + select( + schema.trial, + tunable_config_trial_group_id_subquery, + ) + .where( + schema.trial.c.exp_id == experiment_id, + and_( + tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, + tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, + ), + ) + .order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -110,39 +122,48 @@ def get_results_df( ) cur_trials = conn.execute(cur_trials_stmt) trials_df = pandas.DataFrame( - [( - row.trial_id, - utcify_timestamp(row.ts_start, origin="utc"), - utcify_nullable_timestamp(row.ts_end, origin="utc"), - row.config_id, - row.tunable_config_trial_group_id, - row.status, - ) for row in cur_trials.fetchall()], + [ + ( + row.trial_id, + utcify_timestamp(row.ts_start, origin="utc"), + utcify_nullable_timestamp(row.ts_end, origin="utc"), + row.config_id, + row.tunable_config_trial_group_id, + row.status, + ) + for row in cur_trials.fetchall() + ], columns=[ - 'trial_id', - 'ts_start', - 'ts_end', - 'tunable_config_id', - 'tunable_config_trial_group_id', - 'status', - ] + "trial_id", + "ts_start", + "ts_end", + "tunable_config_id", + "tunable_config_trial_group_id", + "status", + ], ) # Get each trial's config in wide format. 
- configs_stmt = schema.trial.select().with_only_columns( - schema.trial.c.trial_id, - schema.trial.c.config_id, - schema.config_param.c.param_id, - schema.config_param.c.param_value, - ).where( - schema.trial.c.exp_id == experiment_id, - ).join( - schema.config_param, - schema.config_param.c.config_id == schema.trial.c.config_id, - isouter=True - ).order_by( - schema.trial.c.trial_id, - schema.config_param.c.param_id, + configs_stmt = ( + schema.trial.select() + .with_only_columns( + schema.trial.c.trial_id, + schema.trial.c.config_id, + schema.config_param.c.param_id, + schema.config_param.c.param_value, + ) + .where( + schema.trial.c.exp_id == experiment_id, + ) + .join( + schema.config_param, + schema.config_param.c.config_id == schema.trial.c.config_id, + isouter=True, + ) + .order_by( + schema.trial.c.trial_id, + schema.config_param.c.param_id, + ) ) if tunable_config_id is not None: configs_stmt = configs_stmt.where( @@ -150,41 +171,67 @@ def get_results_df( ) configs = conn.execute(configs_stmt) configs_df = pandas.DataFrame( - [(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value) - for row in configs.fetchall()], - columns=['trial_id', 'tunable_config_id', 'param', 'value'] + [ + ( + row.trial_id, + row.config_id, + ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, + row.param_value, + ) + for row in configs.fetchall() + ], + columns=["trial_id", "tunable_config_id", "param", "value"], ).pivot( - index=["trial_id", "tunable_config_id"], columns="param", values="value", + index=["trial_id", "tunable_config_id"], + columns="param", + values="value", ) - configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp) + configs_df = configs_df.apply(pandas.to_numeric, errors="coerce").fillna(configs_df) # type: ignore[assignment] # (fp) # Get each trial's results in wide format. 
- results_stmt = schema.trial_result.select().with_only_columns( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, - schema.trial_result.c.metric_value, - ).where( - schema.trial_result.c.exp_id == experiment_id, - ).order_by( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, + results_stmt = ( + schema.trial_result.select() + .with_only_columns( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + schema.trial_result.c.metric_value, + ) + .where( + schema.trial_result.c.exp_id == experiment_id, + ) + .order_by( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + ) ) if tunable_config_id is not None: - results_stmt = results_stmt.join(schema.trial, and_( - schema.trial.c.exp_id == schema.trial_result.c.exp_id, - schema.trial.c.trial_id == schema.trial_result.c.trial_id, - schema.trial.c.config_id == tunable_config_id, - )) + results_stmt = results_stmt.join( + schema.trial, + and_( + schema.trial.c.exp_id == schema.trial_result.c.exp_id, + schema.trial.c.trial_id == schema.trial_result.c.trial_id, + schema.trial.c.config_id == tunable_config_id, + ), + ) results = conn.execute(results_stmt) results_df = pandas.DataFrame( - [(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value) - for row in results.fetchall()], - columns=['trial_id', 'metric', 'value'] + [ + ( + row.trial_id, + ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, + row.metric_value, + ) + for row in results.fetchall() + ], + columns=["trial_id", "metric", "value"], ).pivot( - index="trial_id", columns="metric", values="value", + index="trial_id", + columns="metric", + values="value", ) - results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp) + results_df = results_df.apply(pandas.to_numeric, errors="coerce").fillna(results_df) # type: ignore[assignment] # (fp) # Concat the trials, configs, and results. - return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \ - .merge(results_df, on="trial_id", how="left") + return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge( + results_df, on="trial_id", how="left" + ) diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 58ee3dddb5..e6322c7ade 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -29,15 +29,18 @@ class Experiment(Storage.Experiment): Logic for retrieving and storing the results of a single experiment. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal['min', 'max']]): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal["min", "max"]], + ): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -55,18 +58,22 @@ def _setup(self) -> None: # Get git info and the last trial ID for the experiment. 
# pylint: disable=not-callable exp_info = conn.execute( - self._schema.experiment.select().with_only_columns( + self._schema.experiment.select() + .with_only_columns( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, func.max(self._schema.trial.c.trial_id).label("trial_id"), - ).join( + ) + .join( self._schema.trial, self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id, - isouter=True - ).where( + isouter=True, + ) + .where( self._schema.experiment.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, @@ -75,33 +82,47 @@ def _setup(self) -> None: if exp_info is None: _LOG.info("Start new experiment: %s", self._experiment_id) # It's a new experiment: create a record for it in the database. - conn.execute(self._schema.experiment.insert().values( - exp_id=self._experiment_id, - description=self._description, - git_repo=self._git_repo, - git_commit=self._git_commit, - root_env_config=self._root_env_config, - )) - conn.execute(self._schema.objectives.insert().values([ - { - "exp_id": self._experiment_id, - "optimization_target": opt_target, - "optimization_direction": opt_dir, - } - for (opt_target, opt_dir) in self.opt_targets.items() - ])) + conn.execute( + self._schema.experiment.insert().values( + exp_id=self._experiment_id, + description=self._description, + git_repo=self._git_repo, + git_commit=self._git_commit, + root_env_config=self._root_env_config, + ) + ) + conn.execute( + self._schema.objectives.insert().values( + [ + { + "exp_id": self._experiment_id, + "optimization_target": opt_target, + "optimization_direction": opt_dir, + } + for (opt_target, opt_dir) in self.opt_targets.items() + ] + ) + ) else: if exp_info.trial_id is not None: self._trial_id = exp_info.trial_id + 1 - _LOG.info("Continue experiment: %s last trial: %s resume from: %d", - self._experiment_id, exp_info.trial_id, self._trial_id) + _LOG.info( + "Continue experiment: %s last trial: %s resume from: %d", + self._experiment_id, + exp_info.trial_id, + self._trial_id, + ) # TODO: Sanity check that certain critical configs (e.g., # objectives) haven't changed to be incompatible such that a new # experiment should be started (possibly by prewarming with the # previous one). if exp_info.git_commit != self._git_commit: - _LOG.warning("Experiment %s git expected: %s %s", - self, exp_info.git_repo, exp_info.git_commit) + _LOG.warning( + "Experiment %s git expected: %s %s", + self, + exp_info.git_repo, + exp_info.git_commit, + ) def merge(self, experiment_ids: List[str]) -> None: _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids) @@ -114,33 +135,42 @@ def load_tunable_config(self, config_id: int) -> Dict[str, Any]: def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select().where( + self._schema.trial_telemetry.select() + .where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == trial_id - ).order_by( + self._schema.trial_telemetry.c.trial_id == trial_id, + ) + .order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) ) # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. 
- return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) - for row in cur_telemetry.fetchall()] + return [ + (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) + for row in cur_telemetry.fetchall() + ] - def load(self, last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load( + self, + last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: with self._engine.connect() as conn: cur_trials = conn.execute( - self._schema.trial.select().with_only_columns( + self._schema.trial.select() + .with_only_columns( self._schema.trial.c.trial_id, self._schema.trial.c.config_id, self._schema.trial.c.status, - ).where( + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id > last_trial_id, - self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']), - ).order_by( + self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]), + ) + .order_by( self._schema.trial.c.trial_id.asc(), ) ) @@ -154,12 +184,21 @@ def load(self, last_trial_id: int = -1, stat = Status[trial.status] status.append(stat) trial_ids.append(trial.trial_id) - configs.append(self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id)) + configs.append( + self._get_key_val( + conn, self._schema.config_param, "param", config_id=trial.config_id + ) + ) if stat.is_succeeded(): - scores.append(self._get_key_val( - conn, self._schema.trial_result, "metric", - exp_id=self._experiment_id, trial_id=trial.trial_id)) + scores.append( + self._get_key_val( + conn, + self._schema.trial_result, + "metric", + exp_id=self._experiment_id, + trial_id=trial.trial_id, + ) + ) else: scores.append(None) @@ -175,49 +214,59 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D select( column(f"{field}_id"), column(f"{field}_value"), - ).select_from(table).where( - *[column(key) == val for (key, val) in kwargs.items()] ) + .select_from(table) + .where(*[column(key) == val for (key, val) in kwargs.items()]) ) # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts. 
- return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access + return dict( + row._tuple() for row in cur_result.fetchall() + ) # pylint: disable=protected-access @staticmethod - def _save_params(conn: Connection, table: Table, - params: Dict[str, Any], **kwargs: Any) -> None: + def _save_params( + conn: Connection, table: Table, params: Dict[str, Any], **kwargs: Any + ) -> None: if not params: return - conn.execute(table.insert(), [ - { - **kwargs, - "param_id": key, - "param_value": nullable(str, val) - } - for (key, val) in params.items() - ]) + conn.execute( + table.insert(), + [ + {**kwargs, "param_id": key, "param_value": nullable(str, val)} + for (key, val) in params.items() + ], + ) def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]: timestamp = utcify_timestamp(timestamp, origin="local") _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp) if running: - pending_status = ['PENDING', 'READY', 'RUNNING'] + pending_status = ["PENDING", "READY", "RUNNING"] else: - pending_status = ['PENDING'] + pending_status = ["PENDING"] with self._engine.connect() as conn: - cur_trials = conn.execute(self._schema.trial.select().where( - self._schema.trial.c.exp_id == self._experiment_id, - (self._schema.trial.c.ts_start.is_(None) | - (self._schema.trial.c.ts_start <= timestamp)), - self._schema.trial.c.ts_end.is_(None), - self._schema.trial.c.status.in_(pending_status), - )) + cur_trials = conn.execute( + self._schema.trial.select().where( + self._schema.trial.c.exp_id == self._experiment_id, + ( + self._schema.trial.c.ts_start.is_(None) + | (self._schema.trial.c.ts_start <= timestamp) + ), + self._schema.trial.c.ts_end.is_(None), + self._schema.trial.c.status.in_(pending_status), + ) + ) for trial in cur_trials.fetchall(): tunables = self._get_key_val( - conn, self._schema.config_param, "param", - config_id=trial.config_id) + conn, self._schema.config_param, "param", config_id=trial.config_id + ) config = self._get_key_val( - conn, self._schema.trial_param, "param", - exp_id=self._experiment_id, trial_id=trial.trial_id) + conn, + self._schema.trial_param, + "param", + exp_id=self._experiment_id, + trial_id=trial.trial_id, + ) yield Trial( engine=self._engine, schema=self._schema, @@ -235,42 +284,55 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: Get the config ID for the given tunables. If the config does not exist, create a new record for it. 
""" - config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest() - cur_config = conn.execute(self._schema.config.select().where( - self._schema.config.c.config_hash == config_hash - )).fetchone() + config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest() + cur_config = conn.execute( + self._schema.config.select().where(self._schema.config.c.config_hash == config_hash) + ).fetchone() if cur_config is not None: return int(cur_config.config_id) # mypy doesn't know it's always int # Config not found, create a new one: - config_id: int = conn.execute(self._schema.config.insert().values( - config_hash=config_hash)).inserted_primary_key[0] + config_id: int = conn.execute( + self._schema.config.insert().values(config_hash=config_hash) + ).inserted_primary_key[0] self._save_params( - conn, self._schema.config_param, + conn, + self._schema.config_param, {tunable.name: tunable.value for (tunable, _group) in tunables}, - config_id=config_id) + config_id=config_id, + ) return config_id - def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> Storage.Trial: + def new_trial( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Storage.Trial: ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local") _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start) with self._engine.begin() as conn: try: config_id = self._get_config_id(conn, tunables) - conn.execute(self._schema.trial.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - config_id=config_id, - ts_start=ts_start, - status='PENDING', - )) + conn.execute( + self._schema.trial.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + config_id=config_id, + ts_start=ts_start, + status="PENDING", + ) + ) # Note: config here is the framework config, not the target # environment config (i.e., tunables). if config is not None: self._save_params( - conn, self._schema.trial_param, config, - exp_id=self._experiment_id, trial_id=self._trial_id) + conn, + self._schema.trial_param, + config, + exp_id=self._experiment_id, + trial_id=self._trial_id, + ) trial = Trial( engine=self._engine, diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index eaa6e1041f..f299bcff68 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -35,14 +35,17 @@ class ExperimentSqlData(ExperimentData): scripts and mlos_bench configuration files. 
""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str, + ): super().__init__( experiment_id=experiment_id, description=description, @@ -57,9 +60,11 @@ def __init__(self, *, def objectives(self) -> Dict[str, Literal["min", "max"]]: with self._engine.connect() as conn: objectives_db_data = conn.execute( - self._schema.objectives.select().where( + self._schema.objectives.select() + .where( self._schema.objectives.c.exp_id == self._experiment_id, - ).order_by( + ) + .order_by( self._schema.objectives.c.weight.desc(), self._schema.objectives.c.optimization_target.asc(), ) @@ -80,13 +85,17 @@ def trials(self) -> Dict[int, TrialData]: def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: with self._engine.connect() as conn: tunable_config_trial_groups = conn.execute( - self._schema.trial.select().with_only_columns( + self._schema.trial.select() + .with_only_columns( self._schema.trial.c.config_id, - func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - 'tunable_config_trial_group_id'), - ).where( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label("tunable_config_trial_group_id"), # pylint: disable=not-callable + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -106,11 +115,14 @@ def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: def tunable_configs(self) -> Dict[int, TunableConfigData]: with self._engine.connect() as conn: tunable_configs = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -139,20 +151,28 @@ def default_tunable_config_id(self) -> Optional[int]: """ with self._engine.connect() as conn: query_results = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial_param.select().with_only_columns( - func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - "first_trial_id_with_defaults"), - ).where( + self._schema.trial_param.select() + .with_only_columns( + func.min(self._schema.trial_param.c.trial_id) + .cast(Integer) + .label("first_trial_id_with_defaults"), # pylint: disable=not-callable + ) + .where( self._schema.trial_param.c.exp_id == self._experiment_id, self._schema.trial_param.c.param_id == "is_defaults", - func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]), - ).scalar_subquery() - ) + func.lower(self._schema.trial_param.c.param_value, type_=String).in_( + ["1", "true"] + ), + ) + .scalar_subquery() + ), ) ) 
min_default_trial_row = query_results.fetchone() @@ -161,17 +181,24 @@ def default_tunable_config_id(self) -> Optional[int]: return min_default_trial_row._tuple()[0] # fallback logic - assume minimum trial_id for experiment query_results = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial.select().with_only_columns( - func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"), - ).where( + self._schema.trial.select() + .with_only_columns( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label("first_trial_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).scalar_subquery() - ) + ) + .scalar_subquery() + ), ) ) min_trial_row = query_results.fetchone() diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 9a1eca2744..65f0e35694 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -80,7 +80,6 @@ def __init__(self, engine: Engine): Column("root_env_config", String(1024), nullable=False), Column("git_repo", String(1024), nullable=False), Column("git_commit", String(40), nullable=False), - PrimaryKeyConstraint("exp_id"), ) @@ -95,20 +94,25 @@ def __init__(self, engine: Engine): # Will need to adjust the insert and return values to support this # eventually. Column("weight", Float, nullable=True), - PrimaryKeyConstraint("exp_id", "optimization_target"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ) # A workaround for SQLAlchemy issue with autoincrement in DuckDB: if engine.dialect.name == "duckdb": - seq_config_id = Sequence('seq_config_id') - col_config_id = Column("config_id", Integer, seq_config_id, - server_default=seq_config_id.next_value(), - nullable=False, primary_key=True) + seq_config_id = Sequence("seq_config_id") + col_config_id = Column( + "config_id", + Integer, + seq_config_id, + server_default=seq_config_id.next_value(), + nullable=False, + primary_key=True, + ) else: - col_config_id = Column("config_id", Integer, nullable=False, - primary_key=True, autoincrement=True) + col_config_id = Column( + "config_id", Integer, nullable=False, primary_key=True, autoincrement=True + ) self.config = Table( "config", @@ -127,7 +131,6 @@ def __init__(self, engine: Engine): Column("ts_end", DateTime), # Should match the text IDs of `mlos_bench.environments.Status` enum: Column("status", String(self._STATUS_LEN), nullable=False), - PrimaryKeyConstraint("exp_id", "trial_id"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), @@ -141,7 +144,6 @@ def __init__(self, engine: Engine): Column("config_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), - PrimaryKeyConstraint("config_id", "param_id"), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), ) @@ -155,10 +157,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), - PrimaryKeyConstraint("exp_id", "trial_id", 
"param_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_status = Table( @@ -168,10 +170,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("status", String(self._STATUS_LEN), nullable=False), - UniqueConstraint("exp_id", "trial_id", "ts"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_result = Table( @@ -181,10 +183,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), - PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_telemetry = Table( @@ -195,15 +197,15 @@ def __init__(self, engine: Engine): Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), - UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) _LOG.debug("Schema: %s", self._meta) - def create(self) -> 'DbSchema': + def create(self) -> "DbSchema": """ Create the DB schema. """ diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index bde38575bd..dec1385cf2 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -27,10 +27,9 @@ class SqlStorage(Storage): An implementation of the Storage interface using SQLAlchemy backend. 
""" - def __init__(self, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, config: dict, global_config: Optional[dict] = None, service: Optional[Service] = None + ): super().__init__(config, global_config, service) lazy_schema_create = self._config.pop("lazy_schema_create", False) self._log_sql = self._config.pop("log_sql", False) @@ -47,7 +46,7 @@ def __init__(self, @property def _schema(self) -> DbSchema: """Lazily create schema upon first access.""" - if not hasattr(self, '_db_schema'): + if not hasattr(self, "_db_schema"): self._db_schema = DbSchema(self._engine).create() if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("DDL statements:\n%s", self._schema) @@ -56,13 +55,16 @@ def _schema(self) -> DbSchema: def __repr__(self) -> str: return self._repr - def experiment(self, *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment: + def experiment( + self, + *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal["min", "max"]], + ) -> Storage.Experiment: return Experiment( engine=self._engine, schema=self._schema, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 7ac7958845..189cc68ebd 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -27,15 +27,18 @@ class Trial(Storage.Trial): Store the results of a single run of the experiment in SQL database. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - config_id: int, - opt_targets: Dict[str, Literal['min', 'max']], - config: Optional[Dict[str, Any]] = None): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + config_id: int, + opt_targets: Dict[str, Literal["min", "max"]], + config: Optional[Dict[str, Any]] = None, + ): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -47,9 +50,9 @@ def __init__(self, *, self._engine = engine self._schema = schema - def update(self, status: Status, timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update( + self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: # Make sure to convert the timestamp to UTC before storing it in the database. 
timestamp = utcify_timestamp(timestamp, origin="local") metrics = super().update(status, timestamp, metrics) @@ -59,13 +62,16 @@ def update(self, status: Status, timestamp: datetime, if status.is_completed(): # Final update of the status and ts_end: cur_status = conn.execute( - self._schema.trial.update().where( + self._schema.trial.update() + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), - ).values( + ["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] + ), + ) + .values( status=status.name, ts_end=timestamp, ) @@ -73,29 +79,37 @@ def update(self, status: Status, timestamp: datetime, if cur_status.rowcount not in {1, -1}: _LOG.warning("Trial %s :: update failed: %s", self, status) raise RuntimeError( - f"Failed to update the status of the trial {self} to {status}." + - f" ({cur_status.rowcount} rows)") + f"Failed to update the status of the trial {self} to {status}." + + f" ({cur_status.rowcount} rows)" + ) if metrics: - conn.execute(self._schema.trial_result.insert().values([ - { - "exp_id": self._experiment_id, - "trial_id": self._trial_id, - "metric_id": key, - "metric_value": nullable(str, val), - } - for (key, val) in metrics.items() - ])) + conn.execute( + self._schema.trial_result.insert().values( + [ + { + "exp_id": self._experiment_id, + "trial_id": self._trial_id, + "metric_id": key, + "metric_value": nullable(str, val), + } + for (key, val) in metrics.items() + ] + ) + ) else: # Update of the status and ts_start when starting the trial: assert metrics is None, f"Unexpected metrics for status: {status}" cur_status = conn.execute( - self._schema.trial.update().where( + self._schema.trial.update() + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), - ).values( + ["RUNNING", "SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] + ), + ) + .values( status=status.name, ts_start=timestamp, ) @@ -108,8 +122,9 @@ def update(self, status: Status, timestamp: datetime, raise return metrics - def update_telemetry(self, status: Status, timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]]) -> None: + def update_telemetry( + self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] + ) -> None: super().update_telemetry(status, timestamp, metrics) # Make sure to convert the timestamp to UTC before storing it in the database. 
timestamp = utcify_timestamp(timestamp, origin="local") @@ -120,16 +135,18 @@ def update_telemetry(self, status: Status, timestamp: datetime, # See Also: comments in with self._engine.begin() as conn: self._update_status(conn, status, timestamp) - for (metric_ts, key, val) in metrics: + for metric_ts, key, val in metrics: with self._engine.begin() as conn: try: - conn.execute(self._schema.trial_telemetry.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=metric_ts, - metric_id=key, - metric_value=nullable(str, val), - )) + conn.execute( + self._schema.trial_telemetry.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=metric_ts, + metric_id=key, + metric_value=nullable(str, val), + ) + ) except IntegrityError as ex: _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex) @@ -141,12 +158,15 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime) # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") try: - conn.execute(self._schema.trial_status.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=timestamp, - status=status.name, - )) + conn.execute( + self._schema.trial_status.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=timestamp, + status=status.name, + ) + ) except IntegrityError as ex: - _LOG.warning("Status with that timestamp already exists: %s %s :: %s", - self, timestamp, ex) + _LOG.warning( + "Status with that timestamp already exists: %s %s :: %s", self, timestamp, ex + ) diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index 5a6f8a5ee8..c5138f91af 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -29,15 +29,18 @@ class TrialSqlData(TrialData): An interface to access the trial data stored in the SQL DB. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - trial_id: int, - config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + trial_id: int, + config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status, + ): super().__init__( experiment_id=experiment_id, trial_id=trial_id, @@ -56,8 +59,9 @@ def tunable_config(self) -> TunableConfigData: Note: this corresponds to the Trial object's "tunables" property. 
""" - return TunableConfigSqlData(engine=self._engine, schema=self._schema, - tunable_config_id=self._tunable_config_id) + return TunableConfigSqlData( + engine=self._engine, schema=self._schema, tunable_config_id=self._tunable_config_id + ) @property def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": @@ -68,9 +72,13 @@ def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": from mlos_bench.storage.sql.tunable_config_trial_group_data import ( TunableConfigTrialGroupSqlData, ) - return TunableConfigTrialGroupSqlData(engine=self._engine, schema=self._schema, - experiment_id=self._experiment_id, - tunable_config_id=self._tunable_config_id) + + return TunableConfigTrialGroupSqlData( + engine=self._engine, + schema=self._schema, + experiment_id=self._experiment_id, + tunable_config_id=self._tunable_config_id, + ) @property def results_df(self) -> pandas.DataFrame: @@ -79,16 +87,19 @@ def results_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_results = conn.execute( - self._schema.trial_result.select().where( + self._schema.trial_result.select() + .where( self._schema.trial_result.c.exp_id == self._experiment_id, - self._schema.trial_result.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_result.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_result.c.metric_id, ) ) return pandas.DataFrame( [(row.metric_id, row.metric_value) for row in cur_results.fetchall()], - columns=['metric', 'value']) + columns=["metric", "value"], + ) @property def telemetry_df(self) -> pandas.DataFrame: @@ -97,10 +108,12 @@ def telemetry_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select().where( + self._schema.trial_telemetry.select() + .where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_telemetry.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) @@ -108,8 +121,12 @@ def telemetry_df(self) -> pandas.DataFrame: # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. 
return pandas.DataFrame( - [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()], - columns=['ts', 'metric', 'value']) + [ + (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) + for row in cur_telemetry.fetchall() + ], + columns=["ts", "metric", "value"], + ) @property def metadata_df(self) -> pandas.DataFrame: @@ -120,13 +137,16 @@ def metadata_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_params = conn.execute( - self._schema.trial_param.select().where( + self._schema.trial_param.select() + .where( self._schema.trial_param.c.exp_id == self._experiment_id, - self._schema.trial_param.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_param.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_params.fetchall()], - columns=['parameter', 'value']) + columns=["parameter", "value"], + ) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py index e484979790..2441f70b9c 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py @@ -20,10 +20,7 @@ class TunableConfigSqlData(TunableConfigData): A configuration in this context is the set of tunable parameter values. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunable_config_id: int): + def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int): super().__init__(tunable_config_id=tunable_config_id) self._engine = engine self._schema = schema @@ -32,12 +29,13 @@ def __init__(self, *, def config_df(self) -> pandas.DataFrame: with self._engine.connect() as conn: cur_config = conn.execute( - self._schema.config_param.select().where( - self._schema.config_param.c.config_id == self._tunable_config_id - ).order_by( + self._schema.config_param.select() + .where(self._schema.config_param.c.config_id == self._tunable_config_id) + .order_by( self._schema.config_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_config.fetchall()], - columns=['parameter', 'value']) + columns=["parameter", "value"], + ) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index eb389a5940..3520e77c60 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -33,12 +33,15 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData): (e.g., for repeats), which we call a (tunable) config trial group. 
""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None, + ): super().__init__( experiment_id=experiment_id, tunable_config_id=tunable_config_id, @@ -53,20 +56,26 @@ def _get_tunable_config_trial_group_id(self) -> int: """ with self._engine.connect() as conn: tunable_config_trial_group = conn.execute( - self._schema.trial.select().with_only_columns( - func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - 'tunable_config_trial_group_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label("tunable_config_trial_group_id"), # pylint: disable=not-callable + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.config_id == self._tunable_config_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) ) row = tunable_config_trial_group.fetchone() assert row is not None - return row._tuple()[0] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy + return row._tuple()[ + 0 + ] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy @property def tunable_config(self) -> TunableConfigData: @@ -86,8 +95,12 @@ def trials(self) -> Dict[int, "TrialData"]: trials : Dict[int, TrialData] A dictionary of the trials' data, keyed by trial id. """ - return common.get_trials(self._engine, self._schema, self._experiment_id, self._tunable_config_id) + return common.get_trials( + self._engine, self._schema, self._experiment_id, self._tunable_config_id + ) @property def results_df(self) -> pandas.DataFrame: - return common.get_results_df(self._engine, self._schema, self._experiment_id, self._tunable_config_id) + return common.get_results_df( + self._engine, self._schema, self._experiment_id, self._tunable_config_id + ) diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index 220f3d812c..22e629fc82 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -13,9 +13,9 @@ from mlos_bench.storage.base_storage import Storage -def from_config(config_file: str, - global_configs: Optional[List[str]] = None, - **kwargs: Any) -> Storage: +def from_config( + config_file: str, global_configs: Optional[List[str]] = None, **kwargs: Any +) -> Storage: """ Create a new storage object from JSON5 config file. 
@@ -36,7 +36,7 @@ def from_config(config_file: str, config_path: List[str] = kwargs.get("config_path", []) config_loader = ConfigPersistenceService({"config_path": config_path}) global_config = {} - for fname in (global_configs or []): + for fname in global_configs or []: config = config_loader.load_config(fname, ConfigSchema.GLOBALS) global_config.update(config) config_path += config.get("config_path", []) diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py index a4610da8de..d16dc81b79 100644 --- a/mlos_bench/mlos_bench/storage/util.py +++ b/mlos_bench/mlos_bench/storage/util.py @@ -25,16 +25,18 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu A dataframe with exactly two columns, 'parameter' (or 'metric') and 'value', where 'parameter' is a string and 'value' is some TunableValue or None. """ - if dataframe.columns.tolist() == ['metric', 'value']: + if dataframe.columns.tolist() == ["metric", "value"]: dataframe = dataframe.copy() - dataframe.rename(columns={'metric': 'parameter'}, inplace=True) - assert dataframe.columns.tolist() == ['parameter', 'value'] + dataframe.rename(columns={"metric": "parameter"}, inplace=True) + assert dataframe.columns.tolist() == ["parameter", "value"] data = {} - for _, row in dataframe.astype('O').iterrows(): - if not isinstance(row['value'], TunableValueTypeTuple): + for _, row in dataframe.astype("O").iterrows(): + if not isinstance(row["value"], TunableValueTypeTuple): raise TypeError(f"Invalid column type: {type(row['value'])} value: {row['value']}") - assert isinstance(row['parameter'], str) - if row['parameter'] in data: + assert isinstance(row["parameter"], str) + if row["parameter"] in data: raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe") - data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value'] + data[row["parameter"]] = ( + try_parse_val(row["value"]) if isinstance(row["value"], str) else row["value"] + ) return data diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index 26aa142441..3b8c23a70c 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -29,26 +29,34 @@ None, ] ZONE_INFO: List[Optional[tzinfo]] = [ - nullable(pytz.timezone, zone_name) - for zone_name in ZONE_NAMES + nullable(pytz.timezone, zone_name) for zone_name in ZONE_NAMES ] # A decorator for tests that require docker. # Use with @requires_docker above a test_...() function. -DOCKER = shutil.which('docker') +DOCKER = shutil.which("docker") if DOCKER: - cmd = run("docker builder inspect default || docker buildx inspect default", shell=True, check=False, capture_output=True) + cmd = run( + "docker builder inspect default || docker buildx inspect default", + shell=True, + check=False, + capture_output=True, + ) stdout = cmd.stdout.decode() - if cmd.returncode != 0 or not any(line for line in stdout.splitlines() if 'Platform' in line and 'linux' in line): + if cmd.returncode != 0 or not any( + line for line in stdout.splitlines() if "Platform" in line and "linux" in line + ): debug("Docker is available but missing support for targeting linux platform.") DOCKER = None -requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.') +requires_docker = pytest.mark.skipif( + not DOCKER, reason="Docker with Linux support is not available on this system." +) # A decorator for tests that require ssh. 
# Use with @requires_ssh above a test_...() function. -SSH = shutil.which('ssh') -requires_ssh = pytest.mark.skipif(not SSH, reason='ssh is not available on this system.') +SSH = shutil.which("ssh") +requires_ssh = pytest.mark.skipif(not SSH, reason="ssh is not available on this system.") # A common seed to use to avoid tracking down race conditions and intermingling # issues of seeds across tests that run in non-deterministic parallel orders. @@ -131,8 +139,14 @@ def are_dir_trees_equal(dir1: str, dir2: str) -> bool: """ # See Also: https://stackoverflow.com/a/6681395 dirs_cmp = filecmp.dircmp(dir1, dir2) - if len(dirs_cmp.left_only) > 0 or len(dirs_cmp.right_only) > 0 or len(dirs_cmp.funny_files) > 0: - warning(f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}") + if ( + len(dirs_cmp.left_only) > 0 + or len(dirs_cmp.right_only) > 0 + or len(dirs_cmp.funny_files) > 0 + ): + warning( + f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}" + ) return False (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) if len(mismatch) > 0 or len(errors) > 0: diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index 4d728b4037..61fb063a52 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -21,9 +21,11 @@ BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace("\\", "/") -def locate_config_examples(root_dir: str, - config_examples_dir: str, - examples_filter: Optional[Callable[[List[str]], List[str]]] = None) -> List[str]: +def locate_config_examples( + root_dir: str, + config_examples_dir: str, + examples_filter: Optional[Callable[[List[str]], List[str]]] = None, +) -> List[str]: """Locates all config examples in the given directory. 
Parameters diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index e1e26d7d8b..7c1d55ef9f 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -43,7 +43,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ - *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), + *locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), ] assert configs @@ -51,7 +53,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.skip(reason="Use full Launcher test (below) instead now.") @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: # pragma: no cover +def test_load_cli_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: # pragma: no cover """Tests loading a config example.""" # pylint: disable=too-complex config = config_loader_service.load_config(config_path, ConfigSchema.CLI) @@ -61,7 +65,7 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic assert isinstance(config_paths, list) config_paths.reverse() for path in config_paths: - config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access + config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access # Foreach arg that references another file, see if we can at least load that too. args_to_skip = { @@ -98,7 +102,9 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_cli_config_examples_via_launcher( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example via the Launcher.""" config = config_loader_service.load_config(config_path, ConfigSchema.CLI) assert isinstance(config, dict) @@ -106,10 +112,12 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # Try to load the CLI config by instantiating a launcher. # To do this we need to make sure to give it a few extra paths and globals # to look for for our examples. 
- cli_args = f"--config {config_path}" + \ - f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + \ - f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + \ - f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" + cli_args = ( + f"--config {config_path}" + + f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + + f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + + f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" + ) launcher = Launcher(description=__name__, long_text=config_path, argv=cli_args.split()) assert launcher @@ -120,15 +128,16 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers assert isinstance(config_paths, list) for path in config_paths: # Note: Checks that the order is maintained are handled in launcher_parse_args.py - assert any(config_path.endswith(path) for config_path in launcher.config_loader.config_paths), \ - f"Expected {path} to be in {launcher.config_loader.config_paths}" + assert any( + config_path.endswith(path) for config_path in launcher.config_loader.config_paths + ), f"Expected {path} to be in {launcher.config_loader.config_paths}" - if 'experiment_id' in config: - assert launcher.global_config['experiment_id'] == config['experiment_id'] - if 'trial_id' in config: - assert launcher.global_config['trial_id'] == config['trial_id'] + if "experiment_id" in config: + assert launcher.global_config["experiment_id"] == config["experiment_id"] + if "trial_id" in config: + assert launcher.global_config["trial_id"] == config["trial_id"] - expected_log_level = logging.getLevelName(config.get('log_level', "INFO")) + expected_log_level = logging.getLevelName(config.get("log_level", "INFO")) if isinstance(expected_log_level, int): expected_log_level = logging.getLevelName(expected_log_level) current_log_level = logging.getLevelName(logging.root.getEffectiveLevel()) @@ -136,7 +145,7 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # TODO: Check that the log_file handler is set correctly. - expected_teardown = config.get('teardown', True) + expected_teardown = config.get("teardown", True) assert launcher.teardown == expected_teardown # Note: Testing of "globals" processing handled in launcher_parse_args_test.py @@ -145,22 +154,30 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # Launcher loaded the expected types as well. 
assert isinstance(launcher.environment, Environment) - env_config = launcher.config_loader.load_config(config["environment"], ConfigSchema.ENVIRONMENT) + env_config = launcher.config_loader.load_config( + config["environment"], ConfigSchema.ENVIRONMENT + ) assert check_class_name(launcher.environment, env_config["class"]) assert isinstance(launcher.optimizer, Optimizer) if "optimizer" in config: - opt_config = launcher.config_loader.load_config(config["optimizer"], ConfigSchema.OPTIMIZER) + opt_config = launcher.config_loader.load_config( + config["optimizer"], ConfigSchema.OPTIMIZER + ) assert check_class_name(launcher.optimizer, opt_config["class"]) assert isinstance(launcher.storage, Storage) if "storage" in config: - storage_config = launcher.config_loader.load_config(config["storage"], ConfigSchema.STORAGE) + storage_config = launcher.config_loader.load_config( + config["storage"], ConfigSchema.STORAGE + ) assert check_class_name(launcher.storage, storage_config["class"]) assert isinstance(launcher.scheduler, Scheduler) if "scheduler" in config: - scheduler_config = launcher.config_loader.load_config(config["scheduler"], ConfigSchema.SCHEDULER) + scheduler_config = launcher.config_loader.load_config( + config["scheduler"], ConfigSchema.SCHEDULER + ) assert check_class_name(launcher.scheduler, scheduler_config["class"]) # TODO: Check that the launcher assigns the tunables values as expected. diff --git a/mlos_bench/mlos_bench/tests/config/conftest.py b/mlos_bench/mlos_bench/tests/config/conftest.py index fdcb3370cf..2c3932a128 100644 --- a/mlos_bench/mlos_bench/tests/config/conftest.py +++ b/mlos_bench/mlos_bench/tests/config/conftest.py @@ -22,9 +22,11 @@ @pytest.fixture def config_loader_service() -> ConfigPersistenceService: """Config loader service fixture.""" - return ConfigPersistenceService(config={ - "config_path": [ - str(files("mlos_bench.tests.config")), - path_join(str(files("mlos_bench.tests.config")), "globals"), - ] - }) + return ConfigPersistenceService( + config={ + "config_path": [ + str(files("mlos_bench.tests.config")), + path_join(str(files("mlos_bench.tests.config")), "globals"), + ] + } + ) diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 42925a0a5d..2369b0c27a 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -27,16 +27,24 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" - configs_to_filter = [config_path for config_path in configs_to_filter if not config_path.endswith("-tunables.jsonc")] + configs_to_filter = [ + config_path + for config_path in configs_to_filter + if not config_path.endswith("-tunables.jsonc") + ] return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_environment_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests 
loading an environment config example.""" envs = load_environment_config_examples(config_loader_service, config_path) for env in envs: @@ -44,11 +52,15 @@ def test_load_environment_config_examples(config_loader_service: ConfigPersisten assert isinstance(env, Environment) -def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]: +def load_environment_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> List[Environment]: """Loads an environment config example.""" # Make sure that any "required_args" are provided. - global_config = config_loader_service.load_config("experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS) - global_config.setdefault('trial_id', 1) # normally populated by Launcher + global_config = config_loader_service.load_config( + "experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS + ) + global_config.setdefault("trial_id", 1) # normally populated by Launcher # Make sure we have the required services for the envs being used. mock_service_configs = [ @@ -60,24 +72,34 @@ def load_environment_config_examples(config_loader_service: ConfigPersistenceSer "services/remote/mock/mock_auth_service.jsonc", ] - tunable_groups = TunableGroups() # base tunable groups that all others get built on + tunable_groups = TunableGroups() # base tunable groups that all others get built on for mock_service_config_path in mock_service_configs: - mock_service_config = config_loader_service.load_config(mock_service_config_path, ConfigSchema.SERVICE) - config_loader_service.register(config_loader_service.build_service( - config=mock_service_config, parent=config_loader_service).export()) + mock_service_config = config_loader_service.load_config( + mock_service_config_path, ConfigSchema.SERVICE + ) + config_loader_service.register( + config_loader_service.build_service( + config=mock_service_config, parent=config_loader_service + ).export() + ) envs = config_loader_service.load_environment_list( - config_path, tunable_groups, global_config, service=config_loader_service) + config_path, tunable_groups, global_config, service=config_loader_service + ) return envs -composite_configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/") +composite_configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/" +) assert composite_configs @pytest.mark.parametrize("config_path", composite_configs) -def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_composite_env_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a composite env config example.""" envs = load_environment_config_examples(config_loader_service, config_path) assert len(envs) == 1 @@ -90,11 +112,15 @@ def test_load_composite_env_config_examples(config_loader_service: ConfigPersist assert child_env.tunable_params is not None checked_child_env_groups = set() - for (child_tunable, child_group) in child_env.tunable_params: + for child_tunable, child_group in child_env.tunable_params: # Lookup that tunable in the composite env. assert child_tunable in composite_env.tunable_params - (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable) - assert child_tunable is composite_tunable # Check that the tunables are the same object. 
+ (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable( + child_tunable + ) + assert ( + child_tunable is composite_tunable + ) # Check that the tunables are the same object. if child_group.name not in checked_child_env_groups: assert child_group is composite_group checked_child_env_groups.add(child_group.name) diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index 4d8c93fdff..fd53d63788 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -29,7 +29,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ # *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), - *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs), + *locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs + ), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, "experiments", filter_configs), ] @@ -37,7 +39,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) -def test_load_globals_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_globals_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.GLOBALS) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index 6cb6253dea..c504a6d50f 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -30,12 +30,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_optimizer_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_optimizer_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index e4264003e1..6d2cabaa8a 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -34,14 +34,17 @@ def __hash__(self) -> int: # The different type of schema test cases we expect to have. 
-_SCHEMA_TEST_TYPES = {x.test_case_type: x for x in ( - SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}), - SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}), -)} +_SCHEMA_TEST_TYPES = { + x.test_case_type: x + for x in ( + SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}), + SchemaTestType(test_case_type="bad", test_case_subtypes={"invalid", "unhandled"}), + ) +} @dataclass -class SchemaTestCaseInfo(): +class SchemaTestCaseInfo: """ Some basic info about a schema test case. """ @@ -61,15 +64,17 @@ def check_schema_dir_layout(test_cases_root: str) -> None: any extra configs or test cases. """ for test_case_dir in os.listdir(test_cases_root): - if test_case_dir == 'README.md': + if test_case_dir == "README.md": continue if test_case_dir not in _SCHEMA_TEST_TYPES: raise NotImplementedError(f"Unhandled test case type: {test_case_dir}") for test_case_subdir in os.listdir(os.path.join(test_cases_root, test_case_dir)): - if test_case_subdir == 'README.md': + if test_case_subdir == "README.md": continue if test_case_subdir not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes: - raise NotImplementedError(f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}") + raise NotImplementedError( + f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}" + ) @dataclass @@ -87,15 +92,21 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: """ Gets a dict of schema test cases from the given root. """ - test_cases = TestCases(by_path={}, - by_type={x: {} for x in _SCHEMA_TEST_TYPES}, - by_subtype={y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes}) + test_cases = TestCases( + by_path={}, + by_type={x: {} for x in _SCHEMA_TEST_TYPES}, + by_subtype={ + y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes + }, + ) check_schema_dir_layout(test_cases_root) # Note: we sort the test cases so that we can deterministically test them in parallel. 
- for (test_case_type, schema_test_type) in _SCHEMA_TEST_TYPES.items(): + for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items(): for test_case_subtype in schema_test_type.test_case_subtypes: - for test_case_file in locate_config_examples(test_cases_root, os.path.join(test_case_type, test_case_subtype)): - with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh: + for test_case_file in locate_config_examples( + test_cases_root, os.path.join(test_case_type, test_case_subtype) + ): + with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh: try: test_case_info = SchemaTestCaseInfo( config=json5.load(test_case_fh), @@ -104,8 +115,12 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: test_case_subtype=test_case_subtype, ) test_cases.by_path[test_case_info.test_case_file] = test_case_info - test_cases.by_type[test_case_info.test_case_type][test_case_info.test_case_file] = test_case_info - test_cases.by_subtype[test_case_info.test_case_subtype][test_case_info.test_case_file] = test_case_info + test_cases.by_type[test_case_info.test_case_type][ + test_case_info.test_case_file + ] = test_case_info + test_cases.by_subtype[test_case_info.test_case_subtype][ + test_case_info.test_case_file + ] = test_case_info except Exception as ex: raise RuntimeError("Failed to load test case: " + test_case_file) from ex assert test_cases @@ -117,7 +132,9 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: return test_cases -def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: +def check_test_case_against_schema( + test_case: SchemaTestCaseInfo, schema_type: ConfigSchema +) -> None: """ Checks the given test case against the given schema. @@ -142,7 +159,9 @@ def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: C raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}") -def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: +def check_test_case_config_with_extra_param( + test_case: SchemaTestCaseInfo, schema_type: ConfigSchema +) -> None: """ Checks that the config fails to validate if extra params are present in certain places. """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 5dd1666008..32ea0b9713 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -26,6 +26,7 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_cli_configs_against_schema(test_case_name: str) -> None: """ @@ -44,7 +45,9 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the cli config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI + ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. 
# The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index dc3cd40425..1528d8d164 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -33,17 +33,21 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_ENV_CLASSES = { - ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. + ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. } -expected_environment_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass - in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench') - if subclass not in NON_CONFIG_ENV_CLASSES] +expected_environment_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Environment, pkg_name="mlos_bench") + if subclass not in NON_CONFIG_ENV_CLASSES +] assert expected_environment_class_names COMPOSITE_ENV_CLASS_NAME = CompositeEnv.__module__ + "." + CompositeEnv.__name__ -expected_leaf_environment_class_names = [subclass_name for subclass_name in expected_environment_class_names - if subclass_name != COMPOSITE_ENV_CLASS_NAME] +expected_leaf_environment_class_names = [ + subclass_name + for subclass_name in expected_environment_class_names + if subclass_name != COMPOSITE_ENV_CLASS_NAME +] # Do the full cross product of all the test cases and all the Environment types. @@ -57,11 +61,13 @@ def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_c if try_resolve_class_name(test_case.config.get("class")) == env_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}") + f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_environment_configs_against_schema(test_case_name: str) -> None: """ @@ -76,5 +82,9 @@ def test_environment_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the environment config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 5045bf510b..508787a84b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -25,6 +25,7 @@ # Now we actually perform all of those validation tests. 
+ @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_globals_configs_against_schema(test_case_name: str) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index e9ee653644..ef5c0edfa3 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -33,9 +33,12 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_optimizer_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Optimizer, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_optimizer_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Optimizer, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_optimizer_class_names # Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None). @@ -50,7 +53,9 @@ # Do the full cross product of all the test cases and all the optimizer types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names) -def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_bench_optimizer_type: str) -> None: +def test_case_coverage_mlos_bench_optimizer_type( + test_case_subtype: str, mlos_bench_optimizer_type: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench optimizer type. """ @@ -58,7 +63,9 @@ def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_be if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}") + f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}" + ) + # Being a little lazy for the moment and relaxing the requirement that we have # a subtype test case for each optimizer and space adapter combo. @@ -67,47 +74,58 @@ def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_be @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types) -def test_case_coverage_mlos_core_optimizer_type(test_case_type: str, - mlos_core_optimizer_type: Optional[OptimizerType]) -> None: +def test_case_coverage_mlos_core_optimizer_type( + test_case_type: str, mlos_core_optimizer_type: Optional[OptimizerType] +) -> None: """ Checks to see if there is a given type of test case for the given mlos_core optimizer type. 
""" optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): - if try_resolve_class_name(test_case.config.get("class")) \ - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": + if ( + try_resolve_class_name(test_case.config.get("class")) + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" + ): optimizer_type = None if test_case.config.get("config"): optimizer_type = test_case.config["config"].get("optimizer_type", None) if optimizer_type == optimizer_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}") + f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}" + ) @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types) -def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str, - mlos_core_space_adapter_type: Optional[SpaceAdapterType]) -> None: +def test_case_coverage_mlos_core_space_adapter_type( + test_case_type: str, mlos_core_space_adapter_type: Optional[SpaceAdapterType] +) -> None: """ Checks to see if there is a given type of test case for the given mlos_core space adapter type. """ - space_adapter_name = None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name + space_adapter_name = ( + None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name + ) for test_case in TEST_CASES.by_type[test_case_type].values(): - if try_resolve_class_name(test_case.config.get("class")) \ - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": + if ( + try_resolve_class_name(test_case.config.get("class")) + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" + ): space_adapter_type = None if test_case.config.get("config"): space_adapter_type = test_case.config["config"].get("space_adapter_type", None) if space_adapter_type == space_adapter_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}") + f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_optimizer_configs_against_schema(test_case_name: str) -> None: """ @@ -122,5 +140,9 @@ def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the optimizer config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 8fccba8bc7..23bd17b1e7 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -30,9 +30,12 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_scheduler_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Scheduler, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_scheduler_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Scheduler, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_scheduler_class_names # Do the full cross product of all the test cases and all the scheduler types. @@ -40,7 +43,9 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names) -def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_bench_scheduler_type: str) -> None: +def test_case_coverage_mlos_bench_scheduler_type( + test_case_subtype: str, mlos_bench_scheduler_type: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench scheduler type. """ @@ -48,7 +53,9 @@ def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_be if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_scheduler_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}") + f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}" + ) + # Now we actually perform all of those validation tests. @@ -67,8 +74,12 @@ def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the scheduler config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 64c6fccccd..032b4c0aad 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -38,16 +38,17 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_SERVICE_CLASSES = { - ConfigPersistenceService, # configured thru the launcher cli args - TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. - AzureDeploymentService, # ABCMeta abstract base class - SshService, # ABCMeta abstract base class + ConfigPersistenceService, # configured thru the launcher cli args + TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. + AzureDeploymentService, # ABCMeta abstract base class + SshService, # ABCMeta abstract base class } -expected_service_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass - in get_all_concrete_subclasses(Service, pkg_name='mlos_bench') - if subclass not in NON_CONFIG_SERVICE_CLASSES] +expected_service_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Service, pkg_name="mlos_bench") + if subclass not in NON_CONFIG_SERVICE_CLASSES +] assert expected_service_class_names @@ -61,7 +62,7 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): config_list: List[Dict[str, Any]] if not isinstance(test_case.config, dict): - continue # type: ignore[unreachable] + continue # type: ignore[unreachable] if "class" not in test_case.config: config_list = test_case.config["services"] else: @@ -70,11 +71,13 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c if try_resolve_class_name(config.get("class")) == service_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for service class {service_class}") + f"Missing test case for subtype {test_case_subtype} for service class {service_class}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_service_configs_against_schema(test_case_name: str) -> None: """ @@ -89,5 +92,9 @@ def test_service_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the service config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index 9b362b5e0d..fd2de83cd0 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -28,9 +28,12 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_storage_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Storage, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_storage_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Storage, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_storage_class_names # Do the full cross product of all the test cases and all the storage types. @@ -38,7 +41,9 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names) -def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_bench_storage_type: str) -> None: +def test_case_coverage_mlos_bench_storage_type( + test_case_subtype: str, mlos_bench_storage_type: str +) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench storage type. """ @@ -46,11 +51,13 @@ def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_benc if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_storage_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}") + f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_storage_configs_against_schema(test_case_name: str) -> None: """ @@ -65,9 +72,15 @@ def test_storage_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the storage config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) - - -if __name__ == '__main__': - pytest.main([__file__, '-n0'],) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) + + +if __name__ == "__main__": + pytest.main( + [__file__, "-n0"], + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py index a6d0de9313..11849119c3 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py @@ -25,6 +25,7 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_params_configs_against_schema(test_case_name: str) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index d871eaa212..33124134e9 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -25,6 +25,7 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_values_configs_against_schema(test_case_name: str) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 32034eb11c..8431251098 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -25,19 +25,27 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" + def predicate(config_path: str) -> bool: - arm_template = config_path.find("services/remote/azure/arm-templates/") >= 0 and config_path.endswith(".jsonc") + arm_template = config_path.find( + "services/remote/azure/arm-templates/" + ) >= 0 and config_path.endswith(".jsonc") setup_rg_scripts = config_path.find("azure/scripts/setup-rg") >= 0 return not (arm_template or setup_rg_scripts) + return [config_path for config_path in configs_to_filter if predicate(config_path)] -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_service_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_service_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) # 
Make an instance of the class based on the config. diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index 2f9773a9b0..d1d39ec4f5 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -29,12 +29,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_storage_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_storage_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.STORAGE) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index 58359eb983..304d4903b3 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -42,7 +42,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score"], }, - tunables=tunable_groups + tunables=tunable_groups, ) @@ -59,7 +59,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score", "other_score"], }, - tunables=tunable_groups + tunables=tunable_groups, ) @@ -103,7 +103,9 @@ def docker_compose_project_name(short_testrun_uid: str) -> str: @pytest.fixture(scope="session") -def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessReaderWriterLock: +def docker_services_lock( + shared_temp_dir: str, short_testrun_uid: str +) -> InterProcessReaderWriterLock: """ Gets a pytest session lock for xdist workers to mark when they're using the docker services. @@ -113,7 +115,9 @@ def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterP A lock to ensure that setup/teardown operations don't happen while a worker is using the docker services. """ - return InterProcessReaderWriterLock(f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock") + return InterProcessReaderWriterLock( + f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock" + ) @pytest.fixture(scope="session") @@ -126,7 +130,9 @@ def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> ------ A lock to ensure that only one worker is doing setup/teardown at a time. 
""" - return InterProcessLock(f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock") + return InterProcessLock( + f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock" + ) @pytest.fixture(scope="session") diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py index ac0b942167..8218577986 100644 --- a/mlos_bench/mlos_bench/tests/environments/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/__init__.py @@ -16,11 +16,13 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def check_env_success(env: Environment, - tunable_groups: TunableGroups, - expected_results: Dict[str, TunableValue], - expected_telemetry: List[Tuple[datetime, str, Any]], - global_config: Optional[dict] = None) -> None: +def check_env_success( + env: Environment, + tunable_groups: TunableGroups, + expected_results: Dict[str, TunableValue], + expected_telemetry: List[Tuple[datetime, str, Any]], + global_config: Optional[dict] = None, +) -> None: """ Set up an environment and run a test experiment there. @@ -50,7 +52,7 @@ def check_env_success(env: Environment, assert telemetry == pytest.approx(expected_telemetry, nan_ok=True) env_context.teardown() - assert not env_context._is_ready # pylint: disable=protected-access + assert not env_context._is_ready # pylint: disable=protected-access def check_env_fail_telemetry(env: Environment, tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index 8afb8e5cda..7be966d482 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -28,9 +28,13 @@ def test_expand_groups() -> None: """ Check the dollar variable expansion for tunable groups. """ - assert Environment._expand_groups( - ["begin", "$list", "$empty", "$str", "end"], - _GROUPS) == ["begin", "c", "d", "efg", "end"] + assert Environment._expand_groups(["begin", "$list", "$empty", "$str", "end"], _GROUPS) == [ + "begin", + "c", + "d", + "efg", + "end", + ] def test_expand_groups_empty_input() -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index 6497eb6985..f7e0e86795 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -40,20 +40,20 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "name": "Env 3 :: tmp_other_3", "class": "mlos_bench.environments.mock_env.MockEnv", "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"], - } + }, ] }, tunables=tunable_groups, service=LocalExecService( - config={ - "temp_dir": "_test_tmp_global" - }, - parent=ConfigPersistenceService({ - "config_path": [ - path_join(os.path.dirname(__file__), "../config", abs_path=True), - ] - }) - ) + config={"temp_dir": "_test_tmp_global"}, + parent=ConfigPersistenceService( + { + "config_path": [ + path_join(os.path.dirname(__file__), "../config", abs_path=True), + ] + } + ), + ), ) @@ -61,7 +61,7 @@ def test_composite_services(composite_env: CompositeEnv) -> None: """ Check that each environment gets its own instance of the services. 
""" - for (i, path) in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): + for i, path in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): service = composite_env.children[i]._service # pylint: disable=protected-access assert service is not None and hasattr(service, "temp_dir_context") with service.temp_dir_context() as temp_dir: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index 742eaf3c79..1a159ef4ef 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -28,7 +28,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", "someConst": "root", - "global_param": "default" + "global_param": "default", }, "children": [ { @@ -43,7 +43,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "someConst", "global_param"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, { "name": "Mock Server Environment 2", @@ -53,12 +53,12 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vmName": "$vm_server_name", "EnvId": 2, - "global_param": "local" + "global_param": "local", }, "required_args": ["vmName"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, { "name": "Mock Control Environment 3", @@ -72,15 +72,13 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "vm_server_name", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } - } - ] + }, + }, + ], }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={ - "global_param": "global_value" - } + global_config={"global_param": "global_value"}, ) @@ -90,26 +88,26 @@ def test_composite_env_params(composite_env: CompositeEnv) -> None: NOTE: The current logic is that variables flow down via required_args and const_args, parent """ assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value" # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the 
child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", - "vm_server_name": "Mock Server VM" + "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config } @@ -118,33 +116,35 @@ def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: Tunabl """ Check that the child environments update their tunable parameters. """ - tunable_groups.assign({ - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + } + ) with composite_env as env_context: assert env_context.setup(tunable_groups) assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value" # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "mwait", # tunable_params from the parent "vm_client_name": "Mock Client VM", "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config @@ -163,7 +163,7 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", - "someConst": "root" + "someConst": "root", }, "children": [ { @@ -191,11 +191,11 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "EnvId", "someConst", "vm_server_name", - "global_param" + "global_param", ], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, # ... ], @@ -220,20 +220,17 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "EnvId", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, # ... 
], }, }, - - ] + ], }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={ - "global_param": "global_value" - } + global_config={"global_param": "global_value"}, ) @@ -244,52 +241,56 @@ def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None """ assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", # "global_param": "global_value" # not required, so not picked from the global_config } -def test_nested_composite_env_setup(nested_composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: +def test_nested_composite_env_setup( + nested_composite_env: CompositeEnv, tunable_groups: TunableGroups +) -> None: """ Check that the child environments update their tunable parameters. 
""" - tunable_groups.assign({ - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + } + ) with nested_composite_env as env_context: assert env_context.setup(tunable_groups) assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", } diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py index 7395aa3e15..cbfd6d75ed 100644 --- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py +++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py @@ -16,9 +16,7 @@ def test_one_group(tunable_groups: TunableGroups) -> None: Make sure only one tunable group is available to the environment. """ env = MockEnv( - name="Test Env", - config={"tunable_params": ["provision"]}, - tunables=tunable_groups + name="Test Env", config={"tunable_params": ["provision"]}, tunables=tunable_groups ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -32,7 +30,7 @@ def test_two_groups(tunable_groups: TunableGroups) -> None: env = MockEnv( name="Test Env", config={"tunable_params": ["provision", "kernel"]}, - tunables=tunable_groups + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -55,7 +53,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups + tunables=tunable_groups, ) expected_params = { "vmSize": "Standard_B4ms", @@ -80,11 +78,7 @@ def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None: """ Make sure that no tunable groups are available to the environment by default. """ - env = MockEnv( - name="Test Env", - config={}, - tunables=tunable_groups - ) + env = MockEnv(name="Test Env", config={}, tunables=tunable_groups) assert env.tunable_params.get_param_values() == {} @@ -93,11 +87,7 @@ def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None: Make sure that no tunable groups are available to the environment when explicitly specifying an empty list of tunable_params. 
""" - env = MockEnv( - name="Test Env", - config={"tunable_params": []}, - tunables=tunable_groups - ) + env = MockEnv(name="Test Env", config={"tunable_params": []}, tunables=tunable_groups) assert env.tunable_params.get_param_values() == {} @@ -114,7 +104,7 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == {} @@ -137,9 +127,7 @@ def test_loader_level_include() -> None: env_json = { "class": "mlos_bench.environments.mock_env.MockEnv", "name": "Test Env", - "include_tunables": [ - "environments/os/linux/boot/linux-boot-tunables.jsonc" - ], + "include_tunables": ["environments/os/linux/boot/linux-boot-tunables.jsonc"], "config": { "tunable_params": ["linux-kernel-boot"], "const_args": { @@ -148,12 +136,14 @@ def test_loader_level_include() -> None: }, }, } - loader = ConfigPersistenceService({ - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - }) + loader = ConfigPersistenceService( + { + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + } + ) env = loader.build_environment(config=env_json, tunables=TunableGroups()) expected_params = { "align_va_addr": "on", diff --git a/mlos_bench/mlos_bench/tests/environments/local/__init__.py b/mlos_bench/mlos_bench/tests/environments/local/__init__.py index 5d8fc32c6b..c68d2fa7b8 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/local/__init__.py @@ -32,14 +32,20 @@ def create_local_env(tunable_groups: TunableGroups, config: Dict[str, Any]) -> L env : LocalEnv A new instance of the local environment. """ - return LocalEnv(name="TestLocalEnv", config=config, tunables=tunable_groups, - service=LocalExecService(parent=ConfigPersistenceService())) + return LocalEnv( + name="TestLocalEnv", + config=config, + tunables=tunable_groups, + service=LocalExecService(parent=ConfigPersistenceService()), + ) -def create_composite_local_env(tunable_groups: TunableGroups, - global_config: Dict[str, Any], - params: Dict[str, Any], - local_configs: List[Dict[str, Any]]) -> CompositeEnv: +def create_composite_local_env( + tunable_groups: TunableGroups, + global_config: Dict[str, Any], + params: Dict[str, Any], + local_configs: List[Dict[str, Any]], +) -> CompositeEnv: """ Create a CompositeEnv with several LocalEnv instances. 
@@ -70,7 +76,7 @@ def create_composite_local_env(tunable_groups: TunableGroups, "config": config, } for (i, config) in enumerate(local_configs) - ] + ], }, tunables=tunable_groups, global_config=global_config, diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index 9bcb7aa218..83dcc3ce5d 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -43,7 +43,7 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - (var_prefix, var_suffix) = ("%", "%") if sys.platform == 'win32' else ("$", "") + (var_prefix, var_suffix) = ("%", "%") if sys.platform == "win32" else ("$", "") env = create_composite_local_env( tunable_groups=tunable_groups, @@ -67,8 +67,8 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo "required_args": ["errors", "reads"], "shell_env_params": [ "latency", # const_args overridden by the composite env - "errors", # Comes from the parent const_args - "reads" # const_args overridden by the global config + "errors", # Comes from the parent const_args + "reads", # const_args overridden by the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -90,9 +90,9 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo }, "required_args": ["writes"], "shell_env_params": [ - "throughput", # const_args overridden by the composite env - "score", # Comes from the local const_args - "writes" # Comes straight from the global config + "throughput", # const_args overridden by the composite env + "score", # Comes from the local const_args + "writes", # Comes straight from the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -106,12 +106,13 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo ], "read_results_file": "output.csv", "read_telemetry_file": "telemetry.csv", - } - ] + }, + ], ) check_env_success( - env, tunable_groups, + env, + tunable_groups, expected_results={ "latency": 4.2, "throughput": 768.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py index 20854b9f9e..bdcd9f885f 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py @@ -17,19 +17,23 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. 
""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", - ], - "results_stdout_pattern": r"(\w+),([0-9.]+)", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", + ], + "results_stdout_pattern": r"(\w+),([0-9.]+)", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -45,19 +49,23 @@ def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. """ - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern - ], - "results_stdout_pattern": r"^(\w+),([0-9.]+)$", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern + ], + "results_stdout_pattern": r"^(\w+),([0-9.]+)$", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -72,24 +80,28 @@ def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to *BOTH* stdout and a file and extract the results from both. 
""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'stdout-msg,string'", - "echo '-------------------'", # Should be ignored - "echo 'metric,value' > output.csv", - "echo 'extra1,333' >> output.csv", - "echo 'extra2,444' >> output.csv", - "echo 'file-msg,string' >> output.csv", - ], - "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'stdout-msg,string'", + "echo '-------------------'", # Should be ignored + "echo 'metric,value' > output.csv", + "echo 'extra1,333' >> output.csv", + "echo 'extra2,444' >> output.csv", + "echo 'file-msg,string' >> output.csv", + ], + "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index 35bdb39486..2491e89e24 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -37,25 +37,29 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,4.1' >> output.csv", - "echo 'throughput,512' >> output.csv", - "echo 'score,0.95' >> output.csv", - "echo '-------------------'", # This output does not go anywhere - "echo 'timestamp,metric,value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_results_file": "output.csv", - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,4.1' >> output.csv", + "echo 'throughput,512' >> output.csv", + "echo 'score,0.95' >> output.csv", + "echo '-------------------'", # This output does not go anywhere + "echo 'timestamp,metric,value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_results_file": "output.csv", + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 4.1, "throughput": 512.0, @@ -72,7 +76,9 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_local_env_telemetry_no_header( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Read the 
telemetry data with no header. """ @@ -84,18 +90,22 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - f"echo {time_str1},cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + f"echo {time_str1},cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={}, expected_telemetry=[ (ts1.astimezone(UTC), "cpu_load", 0.65), @@ -106,9 +116,13 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: ) -@pytest.mark.filterwarnings("ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0") # pylint: disable=line-too-long # noqa +@pytest.mark.filterwarnings( + "ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0" +) # pylint: disable=line-too-long # noqa @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_local_env_telemetry_wrong_header( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Read the telemetry data with incorrect header. 
""" @@ -120,17 +134,20 @@ def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_in time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: the data is correct, but the header has unexpected column names - "echo 'ts,metric_name,metric_value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: the data is correct, but the header has unexpected column names + "echo 'ts,metric_name,metric_value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) @@ -148,16 +165,19 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: too many columns - f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: too many columns + f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) @@ -166,15 +186,18 @@ def test_local_env_telemetry_invalid_ts(tunable_groups: TunableGroups) -> None: """ Fail when the telemetry data has wrong format. """ - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: field 1 must be a timestamp - "echo 1,cpu_load,0.65 > telemetry.csv", - "echo 2,mem_usage,10240 >> telemetry.csv", - "echo 3,cpu_load,0.8 >> telemetry.csv", - "echo 4,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: field 1 must be a timestamp + "echo 1,cpu_load,0.65 > telemetry.csv", + "echo 2,mem_usage,10240 >> telemetry.csv", + "echo 3,cpu_load,0.8 >> telemetry.csv", + "echo 4,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 6cb4fd4f7e..2b51ae1f0e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -16,18 +16,22 @@ def test_local_env(tunable_groups: TunableGroups) -> None: """ Produce benchmark and telemetry data in a local script and read it. 
""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,10' >> output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,10' >> output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 10.0, "throughput": 66.0, @@ -41,9 +45,7 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: """ Basic check that context support for Service mixins are handled when environment contexts are entered. """ - local_env = create_local_env(tunable_groups, { - "run": ["echo NA"] - }) + local_env = create_local_env(tunable_groups, {"run": ["echo NA"]}) # pylint: disable=protected-access assert local_env._service assert not local_env._service._in_context @@ -51,10 +53,10 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: with local_env as env_context: assert env_context._in_context assert local_env._service._in_context - assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) + assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) assert all(svc._in_context for svc in local_env._service._service_contexts) assert all(svc._in_context for svc in local_env._service._services) - assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) + assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) assert not local_env._service._service_contexts assert not any(svc._in_context for svc in local_env._service._services) @@ -63,15 +65,18 @@ def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: """ Fail if the results are not in the expected format. """ - local_env = create_local_env(tunable_groups, { - "run": [ - # No header - "echo 'latency,10' > output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # No header + "echo 'latency,10' > output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) with local_env as env_context: assert env_context.setup(tunable_groups) @@ -83,16 +88,20 @@ def test_local_env_wide(tunable_groups: TunableGroups) -> None: """ Produce benchmark data in wide format and read it. 
""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'latency,throughput,score' > output.csv", - "echo '10,66,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'latency,throughput,score' > output.csv", + "echo '10,66,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 10, "throughput": 66, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py index c16eac4459..c6ece538f1 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py @@ -18,27 +18,30 @@ def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: d """ Check that LocalEnv can set shell environment variables. """ - local_env = create_local_env(tunable_groups, { - "const_args": { - "const_arg": 111, # Passed into "shell_env_params" - "other_arg": 222, # NOT passed into "shell_env_params" + local_env = create_local_env( + tunable_groups, + { + "const_args": { + "const_arg": 111, # Passed into "shell_env_params" + "other_arg": 222, # NOT passed into "shell_env_params" + }, + "tunable_params": ["kernel"], + "shell_env_params": [ + "const_arg", # From "const_arg" + "kernel_sched_latency_ns", # From "tunable_params" + ], + "run": [ + "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", + f"echo {shell_subcmd} >> output.csv", + ], + "read_results_file": "output.csv", }, - "tunable_params": ["kernel"], - "shell_env_params": [ - "const_arg", # From "const_arg" - "kernel_sched_latency_ns", # From "tunable_params" - ], - "run": [ - "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", - f"echo {shell_subcmd} >> output.csv", - ], - "read_results_file": "output.csv", - }) + ) check_env_success(local_env, tunable_groups, expected, []) -@pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only") +@pytest.mark.skipif(sys.platform == "win32", reason="sh-like shell only") def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables in sh-like shell. @@ -47,15 +50,15 @@ def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd="$const_arg,$other_arg,$unknown_arg,$kernel_sched_latency_ns", expected={ - "const_arg": 111, # From "const_args" - "other_arg": float("NaN"), # Not included in "shell_env_params" - "unknown_arg": float("NaN"), # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - } + "const_arg": 111, # From "const_args" + "other_arg": float("NaN"), # Not included in "shell_env_params" + "unknown_arg": float("NaN"), # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + }, ) -@pytest.mark.skipif(sys.platform != 'win32', reason="Windows only") +@pytest.mark.skipif(sys.platform != "win32", reason="Windows only") def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables on Windows / cmd shell. 
@@ -64,9 +67,9 @@ def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd=r"%const_arg%,%other_arg%,%unknown_arg%,%kernel_sched_latency_ns%", expected={ - "const_arg": 111, # From "const_args" - "other_arg": r"%other_arg%", # Not included in "shell_env_params" - "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - } + "const_arg": 111, # From "const_args" + "other_arg": r"%other_arg%", # Not included in "shell_env_params" + "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + }, ) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index 8bce053f7b..25e75cf748 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -25,13 +25,14 @@ def mock_fileshare_service() -> MockFileShareService: """ return MockFileShareService( config={"fileShareName": "MOCK_FILESHARE"}, - parent=LocalExecService(parent=ConfigPersistenceService()) + parent=LocalExecService(parent=ConfigPersistenceService()), ) @pytest.fixture -def local_fileshare_env(tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService) -> LocalFileShareEnv: +def local_fileshare_env( + tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService +) -> LocalFileShareEnv: """ Create a LocalFileShareEnv instance. """ @@ -40,12 +41,12 @@ def local_fileshare_env(tunable_groups: TunableGroups, config={ "const_args": { "experiment_id": "EXP_ID", # Passed into "shell_env_params" - "trial_id": 222, # NOT passed into "shell_env_params" + "trial_id": 222, # NOT passed into "shell_env_params" }, "tunable_params": ["boot"], "shell_env_params": [ - "trial_id", # From "const_arg" - "idle", # From "tunable_params", == "halt" + "trial_id", # From "const_arg" + "idle", # From "tunable_params", == "halt" ], "upload": [ { @@ -57,9 +58,7 @@ def local_fileshare_env(tunable_groups: TunableGroups, "to": "$experiment_id/$trial_id/input/data_$idle.csv", }, ], - "run": [ - "echo No-op run" - ], + "run": ["echo No-op run"], "download": [ { "from": "$experiment_id/$trial_id/$idle/data.csv", @@ -73,9 +72,11 @@ def local_fileshare_env(tunable_groups: TunableGroups, return env -def test_local_fileshare_env(tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService, - local_fileshare_env: LocalFileShareEnv) -> None: +def test_local_fileshare_env( + tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService, + local_fileshare_env: LocalFileShareEnv, +) -> None: """ Test that the LocalFileShareEnv correctly expands the `$VAR` variables in the upload and download sections of the config. 
diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py index 608edbf9ef..c536c97a89 100644 --- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py @@ -42,20 +42,22 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr assert data["score"] == pytest.approx(75.0, 0.01) -@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000 - }, 66.4), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000 - }, 74.06), -]) -def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float) -> None: +@pytest.mark.parametrize( + ("tunable_values", "expected_score"), + [ + ( + {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000}, + 66.4, + ), + ( + {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000}, + 74.06, + ), + ], +) +def test_mock_env_assign( + mock_env: MockEnv, tunable_groups: TunableGroups, tunable_values: dict, expected_score: float +) -> None: """ Check the benchmark values of the mock environment after the assignment. """ @@ -68,21 +70,25 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, assert data["score"] == pytest.approx(expected_score, 0.01) -@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000 - }, 67.5), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000 - }, 75.1), -]) -def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, - tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float) -> None: +@pytest.mark.parametrize( + ("tunable_values", "expected_score"), + [ + ( + {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000}, + 67.5, + ), + ( + {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000}, + 75.1, + ), + ], +) +def test_mock_env_no_noise_assign( + mock_env_no_noise: MockEnv, + tunable_groups: TunableGroups, + tunable_values: dict, + expected_score: float, +) -> None: """ Check the benchmark values of the noiseless mock environment after the assignment. 
""" diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index 878531d799..6d47d1fc61 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -38,25 +38,31 @@ def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None: "ssh_priv_key_path": ssh_test_server.id_rsa_path, } - service = ConfigPersistenceService(config={"config_path": [str(files("mlos_bench.tests.config"))]}) + service = ConfigPersistenceService( + config={"config_path": [str(files("mlos_bench.tests.config"))]} + ) config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc") - env = service.load_environment(config_path, TunableGroups(), global_config=global_config, service=service) + env = service.load_environment( + config_path, TunableGroups(), global_config=global_config, service=service + ) check_env_success( - env, env.tunable_params, + env, + env.tunable_params, expected_results={ "hostname": ssh_test_server.service_name, "username": ssh_test_server.username, "score": 0.9, - "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" + "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" "test_param": "unset", "FOO": "unset", "ssh_username": "unset", }, expected_telemetry=[], ) - assert not os.path.exists(os.path.join(os.getcwd(), "output-downloaded.csv")), \ - "output-downloaded.csv should have been cleaned up by temp_dir context" + assert not os.path.exists( + os.path.join(os.getcwd(), "output-downloaded.csv") + ), "output-downloaded.csv should have been cleaned up by temp_dir context" if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index 377bc940a0..b95666824a 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -40,16 +40,21 @@ def __enter__(self) -> None: self.EVENT_LOOP_CONTEXT.enter() self._in_context = True - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: assert self._in_context self.EVENT_LOOP_CONTEXT.exit() self._in_context = False return False -@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") +@pytest.mark.filterwarnings( + "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" +) def test_event_loop_context() -> None: """Test event loop context background thread setup/cleanup handling.""" # pylint: disable=protected-access,too-many-statements @@ -87,12 +92,16 @@ def test_event_loop_context() -> None: assert event_loop_caller_instance_1._in_context assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 # We should only get one thread for all instances. 
- assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread \ - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread \ + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop_thread - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop \ - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop \ + ) + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop + ) assert not event_loop_caller_instance_2._in_context @@ -104,30 +113,38 @@ def test_event_loop_context() -> None: assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == 'foo' + assert future.result(timeout=0.2) == "foo" assert 0.1 <= time.time() - start <= 0.2 # Once we exit the last context, the background thread should be stopped # and unusable for running co-routines. - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 0 assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None assert not EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() # Check that the event loop has no more tasks. - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_ready') + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_ready") # Windows ProactorEventLoopPolicy adds a dummy task. - if sys.platform == 'win32' and isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop): + if sys.platform == "win32" and isinstance( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop + ): assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 1 else: assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 0 - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_scheduled') + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_scheduled") assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) == 0 - with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + with pytest.raises( + AssertionError + ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) raise ValueError(f"Future should not have been available to wait on {future.result()}") # Test that when re-entering the context we have the same event loop. @@ -138,12 +155,14 @@ def test_event_loop_context() -> None: # Test running again. 
start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == 'foo' + assert future.result(timeout=0.2) == "foo" assert 0.1 <= time.time() - start <= 0.2 -if __name__ == '__main__': +if __name__ == "__main__": # For debugging in Windows which has issues with pytest detection in vscode. pytest.main(["-n1", "--dist=no", "-k", "test_event_loop_context"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py index 90aa7e08f7..25abf659ce 100644 --- a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py @@ -14,19 +14,33 @@ @pytest.mark.parametrize( - ("argv", "expected_score"), [ - ([ - "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", - "--trial_config_repeat_count", "5", - "--mock_env_seed", "-1", # Deterministic Mock Environment. - ], 67.40329), - ([ - "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", - "--trial_config_repeat_count", "3", - "--max_suggestions", "3", - "--mock_env_seed", "42", # Noisy Mock Environment. - ], 64.53897), - ] + ("argv", "expected_score"), + [ + ( + [ + "--config", + "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", + "--trial_config_repeat_count", + "5", + "--mock_env_seed", + "-1", # Deterministic Mock Environment. + ], + 67.40329, + ), + ( + [ + "--config", + "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", + "--trial_config_repeat_count", + "3", + "--max_suggestions", + "3", + "--mock_env_seed", + "42", # Noisy Mock Environment. + ], + 64.53897, + ), + ], ) def test_main_bench(argv: List[str], expected_score: float) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index 634050d099..b03c5a2733 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -48,8 +48,8 @@ def config_paths() -> List[str]: """ return [ path_join(os.getcwd(), abs_path=True), - str(files('mlos_bench.config')), - str(files('mlos_bench.tests.config')), + str(files("mlos_bench.config")), + str(files("mlos_bench.tests.config")), ] @@ -64,20 +64,23 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == 'win32': + if sys.platform == "win32": # Some env tweaks for platform compatibility. - environ['USER'] = environ['USERNAME'] + environ["USER"] = environ["USERNAME"] # This is part of the minimal required args by the Launcher. 
- env_conf_path = 'environments/mock/mock_env.jsonc' - cli_args = '--config-paths ' + ' '.join(config_paths) + \ - ' --service services/remote/mock/mock_auth_service.jsonc' + \ - ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ - ' --scheduler schedulers/sync_scheduler.jsonc' + \ - f' --environment {env_conf_path}' + \ - ' --globals globals/global_test_config.jsonc' + \ - ' --globals globals/global_test_extra_config.jsonc' \ - ' --test_global_value_2 from-args' + env_conf_path = "environments/mock/mock_env.jsonc" + cli_args = ( + "--config-paths " + + " ".join(config_paths) + + " --service services/remote/mock/mock_auth_service.jsonc" + + " --service services/remote/mock/mock_remote_exec_service.jsonc" + + " --scheduler schedulers/sync_scheduler.jsonc" + + f" --environment {env_conf_path}" + + " --globals globals/global_test_config.jsonc" + + " --globals globals/global_test_extra_config.jsonc" + " --test_global_value_2 from-args" + ) launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -85,27 +88,28 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsLocalExec) assert isinstance(launcher.service, SupportsRemoteExec) # Check that the first --globals file is loaded and $var expansion is handled. - assert launcher.global_config['experiment_id'] == 'MockExperiment' - assert launcher.global_config['testVmName'] == 'MockExperiment-vm' + assert launcher.global_config["experiment_id"] == "MockExperiment" + assert launcher.global_config["testVmName"] == "MockExperiment-vm" # Check that secondary expansion also works. - assert launcher.global_config['testVnetName'] == 'MockExperiment-vm-vnet' + assert launcher.global_config["testVnetName"] == "MockExperiment-vm-vnet" # Check that the second --globals file is loaded. - assert launcher.global_config['test_global_value'] == 'from-file' + assert launcher.global_config["test_global_value"] == "from-file" # Check overriding values in a file from the command line. - assert launcher.global_config['test_global_value_2'] == 'from-args' + assert launcher.global_config["test_global_value_2"] == "from-args" # Check that we can expand a $var in a config file that references an environment variable. - assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ - == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join( + os.getcwd(), "foo", abs_path=True + ) + assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert launcher.teardown # Check that the environment that got loaded looks to be of the right type. env_config = launcher.config_loader.load_config(env_conf_path, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config['class']) + assert check_class_name(launcher.environment, env_config["class"]) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, OneShotOptimizer) # Check that the optimizer got initialized with defaults. 
assert launcher.optimizer.tunable_params.is_defaults() - assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer + assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer # Check that we pick up the right scheduler config: assert isinstance(launcher.scheduler, SyncScheduler) assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access @@ -122,23 +126,25 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == 'win32': + if sys.platform == "win32": # Some env tweaks for platform compatibility. - environ['USER'] = environ['USERNAME'] - - config_file = 'cli/test-cli-config.jsonc' - globals_file = 'globals/global_test_config.jsonc' - cli_args = ' '.join([f"--config-path {config_path}" for config_path in config_paths]) + \ - f' --config {config_file}' + \ - ' --service services/remote/mock/mock_auth_service.jsonc' + \ - ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ - f' --globals {globals_file}' + \ - ' --experiment_id MockeryExperiment' + \ - ' --no-teardown' + \ - ' --random-init' + \ - ' --random-seed 1234' + \ - ' --trial-config-repeat-count 5' + \ - ' --max_trials 200' + environ["USER"] = environ["USERNAME"] + + config_file = "cli/test-cli-config.jsonc" + globals_file = "globals/global_test_config.jsonc" + cli_args = ( + " ".join([f"--config-path {config_path}" for config_path in config_paths]) + + f" --config {config_file}" + + " --service services/remote/mock/mock_auth_service.jsonc" + + " --service services/remote/mock/mock_remote_exec_service.jsonc" + + f" --globals {globals_file}" + + " --experiment_id MockeryExperiment" + + " --no-teardown" + + " --random-init" + + " --random-seed 1234" + + " --trial-config-repeat-count 5" + + " --max_trials 200" + ) launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -148,35 +154,42 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsRemoteExec) # Check that the --globals file is loaded and $var expansion is handled # using the value provided on the CLI. - assert launcher.global_config['experiment_id'] == 'MockeryExperiment' - assert launcher.global_config['testVmName'] == 'MockeryExperiment-vm' + assert launcher.global_config["experiment_id"] == "MockeryExperiment" + assert launcher.global_config["testVmName"] == "MockeryExperiment-vm" # Check that secondary expansion also works. - assert launcher.global_config['testVnetName'] == 'MockeryExperiment-vm-vnet' + assert launcher.global_config["testVnetName"] == "MockeryExperiment-vm-vnet" # Check that we can expand a $var in a config file that references an environment variable. 
- assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ - == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join( + os.getcwd(), "foo", abs_path=True + ) + assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert not launcher.teardown config = launcher.config_loader.load_config(config_file, ConfigSchema.CLI) - assert launcher.config_loader.config_paths == [path_join(path, abs_path=True) for path in config_paths + config['config_path']] + assert launcher.config_loader.config_paths == [ + path_join(path, abs_path=True) for path in config_paths + config["config_path"] + ] # Check that the environment that got loaded looks to be of the right type. - env_config_file = config['environment'] + env_config_file = config["environment"] env_config = launcher.config_loader.load_config(env_config_file, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config['class']) + assert check_class_name(launcher.environment, env_config["class"]) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, MlosCoreOptimizer) - opt_config_file = config['optimizer'] + opt_config_file = config["optimizer"] opt_config = launcher.config_loader.load_config(opt_config_file, ConfigSchema.OPTIMIZER) globals_file_config = launcher.config_loader.load_config(globals_file, ConfigSchema.GLOBALS) # The actual global_config gets overwritten as a part of processing, so to test # this we read the original value out of the source files. - orig_max_iters = globals_file_config.get('max_suggestions', opt_config.get('config', {}).get('max_suggestions', 100)) - assert launcher.optimizer.max_iterations \ - == orig_max_iters \ - == launcher.global_config['max_suggestions'] + orig_max_iters = globals_file_config.get( + "max_suggestions", opt_config.get("config", {}).get("max_suggestions", 100) + ) + assert ( + launcher.optimizer.max_iterations + == orig_max_iters + == launcher.global_config["max_suggestions"] + ) # Check that the optimizer got initialized with random values instead of the defaults. # Note: the environment doesn't get updated until suggest() is called to @@ -193,12 +206,12 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access # Check that the value from the file is overridden by the CLI arg. - assert config['random_seed'] == 42 + assert config["random_seed"] == 42 # TODO: This isn't actually respected yet because the `--random-init` only # applies to a temporary Optimizer used to populate the initial values via # random sampling. # assert launcher.optimizer.seed == 1234 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__, "-n1"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index 591501d275..8fff9b5dd5 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -31,16 +31,21 @@ def local_exec_service() -> LocalExecService: """ Test fixture for LocalExecService. 
""" - return LocalExecService(parent=ConfigPersistenceService({ - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - })) + return LocalExecService( + parent=ConfigPersistenceService( + { + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + } + ) + ) -def _launch_main_app(root_path: str, local_exec_service: LocalExecService, - cli_config: str, re_expected: List[str]) -> None: +def _launch_main_app( + root_path: str, local_exec_service: LocalExecService, cli_config: str, re_expected: List[str] +) -> None: """ Run mlos_bench command-line application with given config and check the results in the log. @@ -52,10 +57,13 @@ def _launch_main_app(root_path: str, local_exec_service: LocalExecService, # temp_dir = '/tmp' log_path = path_join(temp_dir, "mock-test.log") (return_code, _stdout, _stderr) = local_exec_service.local_exec( - ["./mlos_bench/mlos_bench/run.py" + - " --config_path ./mlos_bench/mlos_bench/tests/config/" + - f" {cli_config} --log_file '{log_path}'"], - cwd=root_path) + [ + "./mlos_bench/mlos_bench/run.py" + + " --config_path ./mlos_bench/mlos_bench/tests/config/" + + f" {cli_config} --log_file '{log_path}'" + ], + cwd=root_path, + ) assert return_code == 0 try: @@ -79,33 +87,34 @@ def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecServ and default tunable values and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - " --config cli/mock-bench.jsonc" + - " --trial_config_repeat_count 5" + - " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, + local_exec_service, + " --config cli/mock-bench.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. [ - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", + ], ) def test_launch_main_app_bench_values( - root_path: str, local_exec_service: LocalExecService) -> None: + root_path: str, local_exec_service: LocalExecService +) -> None: """ Run mlos_bench command-line application with mock benchmark config and user-specified tunable values and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - " --config cli/mock-bench.jsonc" + - " --tunable_values tunable-values/tunable-values-example.jsonc" + - " --trial_config_repeat_count 5" + - " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, + local_exec_service, + " --config cli/mock-bench.jsonc" + + " --tunable_values tunable-values/tunable-values-example.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. [ - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", + ], ) @@ -115,23 +124,23 @@ def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecServic and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - "--config cli/mock-opt.jsonc" + - " --trial_config_repeat_count 3" + - " --max_suggestions 3" + - " --mock_env_seed 42", # Noisy Mock Environment. + root_path, + local_exec_service, + "--config cli/mock-opt.jsonc" + + " --trial_config_repeat_count 3" + + " --max_suggestions 3" + + " --mock_env_seed 42", # Noisy Mock Environment. 
[ # Iteration 1: Expect first value to be the baseline - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", # Iteration 2: The result may not always be deterministic - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Iteration 3: non-deterministic (depends on the optimizer) - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Final result: baseline is the optimum for the mock environment - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", + ], ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index 59a0fac13b..924224365c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -23,29 +23,29 @@ def mock_configs() -> List[dict]: """ return [ { - 'vmSize': 'Standard_B4ms', - 'idle': 'halt', - 'kernel_sched_migration_cost_ns': 50000, - 'kernel_sched_latency_ns': 1000000, + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 50000, + "kernel_sched_latency_ns": 1000000, }, { - 'vmSize': 'Standard_B4ms', - 'idle': 'halt', - 'kernel_sched_migration_cost_ns': 40000, - 'kernel_sched_latency_ns': 2000000, + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000, + "kernel_sched_latency_ns": 2000000, }, { - 'vmSize': 'Standard_B4ms', - 'idle': 'mwait', - 'kernel_sched_migration_cost_ns': -1, # Special value - 'kernel_sched_latency_ns': 3000000, + "vmSize": "Standard_B4ms", + "idle": "mwait", + "kernel_sched_migration_cost_ns": -1, # Special value + "kernel_sched_latency_ns": 3000000, }, { - 'vmSize': 'Standard_B2s', - 'idle': 'mwait', - 'kernel_sched_migration_cost_ns': 200000, - 'kernel_sched_latency_ns': 4000000, - } + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 200000, + "kernel_sched_latency_ns": 4000000, + }, ] @@ -61,7 +61,7 @@ def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: "optimization_targets": {"score": "min"}, "max_suggestions": 5, "start_with_defaults": False, - "seed": SEED + "seed": SEED, }, ) @@ -74,11 +74,7 @@ def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: return MockOptimizer( tunables=tunable_groups, service=None, - config={ - "optimization_targets": {"score": "min"}, - "max_suggestions": 5, - "seed": SEED - }, + config={"optimization_targets": {"score": "min"}, "max_suggestions": 5, "seed": SEED}, ) @@ -90,11 +86,7 @@ def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: return MockOptimizer( tunables=tunable_groups, service=None, - config={ - "optimization_targets": {"score": "max"}, - "max_suggestions": 10, - "seed": SEED - }, + config={"optimization_targets": {"score": "max"}, "max_suggestions": 10, "seed": SEED}, ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py 
b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index 9e9ce25d6f..add2945d74 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -20,6 +20,7 @@ # pylint: disable=redefined-outer-name + @pytest.fixture def grid_search_tunables_config() -> dict: """ @@ -51,14 +52,22 @@ def grid_search_tunables_config() -> dict: @pytest.fixture -def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]: +def grid_search_tunables_grid( + grid_search_tunables: TunableGroups, +) -> List[Dict[str, TunableValue]]: """ Test fixture for grid from tunable groups. Used to check that the grids are the same (ignoring order). """ - tunables_params_values = [tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None] - tunable_names = tuple(tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None) - return list(dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values)) + tunables_params_values = [ + tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None + ] + tunable_names = tuple( + tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None + ) + return list( + dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values) + ) @pytest.fixture @@ -70,8 +79,9 @@ def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups: @pytest.fixture -def grid_search_opt(grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer: +def grid_search_opt( + grid_search_tunables: TunableGroups, grid_search_tunables_grid: List[Dict[str, TunableValue]] +) -> GridSearchOptimizer: """ Test fixture for grid search optimizer. """ @@ -79,15 +89,20 @@ def grid_search_opt(grid_search_tunables: TunableGroups, # Test the convergence logic by controlling the number of iterations to be not a # multiple of the number of elements in the grid. max_iterations = len(grid_search_tunables_grid) * 2 - 3 - return GridSearchOptimizer(tunables=grid_search_tunables, config={ - "max_suggestions": max_iterations, - "optimization_targets": {"score": "max", "other_score": "min"}, - }) + return GridSearchOptimizer( + tunables=grid_search_tunables, + config={ + "max_suggestions": max_iterations, + "optimization_targets": {"score": "max", "other_score": "min"}, + }, + ) -def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: +def test_grid_search_grid( + grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], +) -> None: """ Make sure that grid search optimizer initializes and works correctly. """ @@ -114,9 +129,11 @@ def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, # assert grid_search_opt.pending_configs == grid_search_tunables_grid -def test_grid_search(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: +def test_grid_search( + grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], +) -> None: """ Make sure that grid search optimizer initializes and works correctly. 
""" @@ -143,7 +160,9 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, grid_search_tunables_grid.remove(default_config) assert default_config not in grid_search_opt.pending_configs assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) + assert all( + config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid + ) # The next suggestion should be a different element in the grid search. suggestion = grid_search_opt.suggest() @@ -157,7 +176,9 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, grid_search_tunables_grid.remove(suggestion.get_param_values()) assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) + assert all( + config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid + ) # We consider not_converged as either having reached "max_suggestions" or an empty grid? @@ -223,7 +244,7 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: assert best_suggestion_dict not in grid_search_opt.suggested_configs best_suggestion_score: Dict[str, TunableValue] = {} - for (opt_target, opt_dir) in grid_search_opt.targets.items(): + for opt_target, opt_dir in grid_search_opt.targets.items(): val = score[opt_target] assert isinstance(val, (int, float)) best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1 @@ -237,36 +258,54 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: # Check bulk register suggested = [grid_search_opt.suggest() for _ in range(suggest_count)] - assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) - assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested) + assert all( + suggestion.get_param_values() not in grid_search_opt.pending_configs + for suggestion in suggested + ) + assert all( + suggestion.get_param_values() in grid_search_opt.suggested_configs + for suggestion in suggested + ) # Those new suggestions also shouldn't be in the set of previously suggested configs. 
assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested) - grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested], - [score] * len(suggested), - [status] * len(suggested)) - - assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) - assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested) + grid_search_opt.bulk_register( + [suggestion.get_param_values() for suggestion in suggested], + [score] * len(suggested), + [status] * len(suggested), + ) + + assert all( + suggestion.get_param_values() not in grid_search_opt.pending_configs + for suggestion in suggested + ) + assert all( + suggestion.get_param_values() not in grid_search_opt.suggested_configs + for suggestion in suggested + ) best_score, best_config = grid_search_opt.get_best_observation() assert best_score == best_suggestion_score assert best_config == best_suggestion -def test_grid_search_register(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups) -> None: +def test_grid_search_register( + grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups +) -> None: """ Make sure that the `.register()` method adjusts the score signs correctly. """ assert grid_search_opt.register( - grid_search_tunables, Status.SUCCEEDED, { + grid_search_tunables, + Status.SUCCEEDED, + { "score": 1.0, "other_score": 2.0, - }) == { - "score": -1.0, # max - "other_score": 2.0, # min + }, + ) == { + "score": -1.0, # max + "other_score": 2.0, # min } assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == { diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py index 6549a8795c..3a0ef7db2e 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py @@ -34,7 +34,8 @@ def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: "optimizer_type": "SMAC", "seed": SEED, # "start_with_defaults": False, - }) + }, + ) @pytest.fixture @@ -61,6 +62,6 @@ def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list assert best_score["score"] == pytest.approx(66.66, 0.01) -if __name__ == '__main__': +if __name__ == "__main__": # For attaching debugger debugging: pytest.main(["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__]) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py index 7ebba0e664..c824d9774f 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py @@ -24,9 +24,9 @@ def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer: An instance of a mlos_core optimizer (FLAML-based). 
""" test_opt_config = { - 'optimizer_type': 'FLAML', - 'max_suggestions': 10, - 'seed': SEED, + "optimizer_type": "FLAML", + "max_suggestions": 10, + "seed": SEED, } return MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -39,44 +39,44 @@ def test_df(mlos_core_optimizer: MlosCoreOptimizer, mock_configs: List[dict]) -> assert isinstance(df_config, pandas.DataFrame) assert df_config.shape == (4, 6) assert set(df_config.columns) == { - 'kernel_sched_latency_ns', - 'kernel_sched_migration_cost_ns', - 'kernel_sched_migration_cost_ns!type', - 'kernel_sched_migration_cost_ns!special', - 'idle', - 'vmSize', + "kernel_sched_latency_ns", + "kernel_sched_migration_cost_ns", + "kernel_sched_migration_cost_ns!type", + "kernel_sched_migration_cost_ns!special", + "idle", + "vmSize", } - assert df_config.to_dict(orient='records') == [ + assert df_config.to_dict(orient="records") == [ { - 'idle': 'halt', - 'kernel_sched_latency_ns': 1000000, - 'kernel_sched_migration_cost_ns': 50000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B4ms', + "idle": "halt", + "kernel_sched_latency_ns": 1000000, + "kernel_sched_migration_cost_ns": 50000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B4ms", }, { - 'idle': 'halt', - 'kernel_sched_latency_ns': 2000000, - 'kernel_sched_migration_cost_ns': 40000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B4ms', + "idle": "halt", + "kernel_sched_latency_ns": 2000000, + "kernel_sched_migration_cost_ns": 40000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B4ms", }, { - 'idle': 'mwait', - 'kernel_sched_latency_ns': 3000000, - 'kernel_sched_migration_cost_ns': None, # The value is special! - 'kernel_sched_migration_cost_ns!special': -1, - 'kernel_sched_migration_cost_ns!type': 'special', - 'vmSize': 'Standard_B4ms', + "idle": "mwait", + "kernel_sched_latency_ns": 3000000, + "kernel_sched_migration_cost_ns": None, # The value is special! + "kernel_sched_migration_cost_ns!special": -1, + "kernel_sched_migration_cost_ns!type": "special", + "vmSize": "Standard_B4ms", }, { - 'idle': 'mwait', - 'kernel_sched_latency_ns': 4000000, - 'kernel_sched_migration_cost_ns': 200000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B2s', + "idle": "mwait", + "kernel_sched_latency_ns": 4000000, + "kernel_sched_migration_cost_ns": 200000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B2s", }, ] diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index fc62b4ff1b..9d696e01fa 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -17,8 +17,8 @@ from mlos_bench.util import path_join from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer -_OUTPUT_DIR_PATH_BASE = r'c:/temp' if sys.platform == 'win32' else '/tmp/' -_OUTPUT_DIR = '_test_output_dir' # Will be deleted after the test. +_OUTPUT_DIR_PATH_BASE = r"c:/temp" if sys.platform == "win32" else "/tmp/" +_OUTPUT_DIR = "_test_output_dir" # Will be deleted after the test. 
def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) -> None: @@ -26,10 +26,10 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) Test invalid max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'max_trials': 10, - 'max_suggestions': 11, - 'seed': SEED, + "optimizer_type": "SMAC", + "max_trials": 10, + "max_suggestions": 11, + "seed": SEED, } with pytest.raises(AssertionError): opt = MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -41,14 +41,14 @@ def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> No Test max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'max_suggestions': 123, - 'seed': SEED, + "optimizer_type": "SMAC", + "max_suggestions": 123, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) - assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config['max_suggestions'] + assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config["max_suggestions"] def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGroups) -> None: @@ -57,9 +57,9 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr """ output_dir = path_join(_OUTPUT_DIR_PATH_BASE, _OUTPUT_DIR) test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': output_dir, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": output_dir, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) @@ -67,7 +67,8 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr assert isinstance(opt._opt, SmacOptimizer) # Final portions of the path are generated by SMAC when run_name is not specified. assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - str(test_opt_config['output_directory'])) + str(test_opt_config["output_directory"]) + ) shutil.rmtree(output_dir) @@ -76,56 +77,67 @@ def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGr Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']))) + path_join(os.getcwd(), str(test_opt_config["output_directory"])) + ) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_run_name( + tunable_groups: TunableGroups, +) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. 
""" test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'run_name': 'test_run', - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "run_name": "test_run", + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']), str(test_opt_config['run_name']))) + path_join( + os.getcwd(), str(test_opt_config["output_directory"]), str(test_opt_config["run_name"]) + ) + ) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_experiment_id( + tunable_groups: TunableGroups, +) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "seed": SEED, } global_config = { - 'experiment_id': 'experiment_id', + "experiment_id": "experiment_id", } opt = MlosCoreOptimizer(tunable_groups, test_opt_config, global_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']), global_config['experiment_id'])) + path_join( + os.getcwd(), str(test_opt_config["output_directory"]), global_config["experiment_id"] + ) + ) shutil.rmtree(_OUTPUT_DIR) @@ -134,9 +146,9 @@ def test_init_mlos_core_smac_temp_output_directory(tunable_groups: TunableGroups Test random output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': None, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": None, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py index a94a315939..b95d943272 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py @@ -20,24 +20,33 @@ def mock_configurations_no_defaults() -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. 
""" return [ - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 13112, - "kernel_sched_latency_ns": 796233790, - }, 88.88), - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 117026, - "kernel_sched_latency_ns": 149827706, - }, 66.66), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 354785, - "kernel_sched_latency_ns": 795285932, - }, 99.99), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 13112, + "kernel_sched_latency_ns": 796233790, + }, + 88.88, + ), + ( + { + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 117026, + "kernel_sched_latency_ns": 149827706, + }, + 66.66, + ), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 354785, + "kernel_sched_latency_ns": 795285932, + }, + 99.99, + ), ] @@ -47,12 +56,15 @@ def mock_configurations(mock_configurations_no_defaults: list) -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. """ return [ - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": -1, - "kernel_sched_latency_ns": 2000000, - }, 88.88), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": -1, + "kernel_sched_latency_ns": 2000000, + }, + 88.88, + ), ] + mock_configurations_no_defaults @@ -60,7 +72,7 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: """ Run several iterations of the optimizer and return the best score. """ - for (tunable_values, score) in mock_configurations: + for tunable_values, score in mock_configurations: assert mock_opt.not_converged() tunables = mock_opt.suggest() assert tunables.get_param_values() == tunable_values @@ -80,8 +92,9 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> N assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer, - mock_configurations_no_defaults: list) -> None: +def test_mock_optimizer_no_defaults( + mock_opt_no_defaults: MockOptimizer, mock_configurations_no_defaults: list +) -> None: """ Make sure that mock optimizer produces consistent suggestions. """ diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py index bf37040f13..ccc0ba8137 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py @@ -25,10 +25,7 @@ def mock_configs_str(mock_configs: List[dict]) -> List[dict]: Same as `mock_config` above, but with all values converted to strings. (This can happen when we retrieve the data from storage). 
""" - return [ - {key: str(val) for (key, val) in config.items()} - for config in mock_configs - ] + return [{key: str(val) for (key, val) in config.items()} for config in mock_configs] @pytest.fixture @@ -52,10 +49,12 @@ def mock_status() -> List[Status]: return [Status.FAILED, Status.SUCCEEDED, Status.SUCCEEDED, Status.SUCCEEDED] -def _test_opt_update_min(opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None) -> None: +def _test_opt_update_min( + opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None, +) -> None: """ Test the bulk update of the optimizer on the minimization problem. """ @@ -68,14 +67,16 @@ def _test_opt_update_min(opt: Optimizer, "vmSize": "Standard_B4ms", "idle": "mwait", "kernel_sched_migration_cost_ns": -1, - 'kernel_sched_latency_ns': 3000000, + "kernel_sched_latency_ns": 3000000, } -def _test_opt_update_max(opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None) -> None: +def _test_opt_update_max( + opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None, +) -> None: """ Test the bulk update of the optimizer on the maximization problem. """ @@ -88,14 +89,16 @@ def _test_opt_update_max(opt: Optimizer, "vmSize": "Standard_B2s", "idle": "mwait", "kernel_sched_migration_cost_ns": 200000, - 'kernel_sched_latency_ns': 4000000, + "kernel_sched_latency_ns": 4000000, } -def test_update_mock_min(mock_opt: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_min( + mock_opt: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the mock optimizer on the minimization problem. """ @@ -105,64 +108,76 @@ def test_update_mock_min(mock_opt: MockOptimizer, "vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 13112, - 'kernel_sched_latency_ns': 796233790, + "kernel_sched_latency_ns": 796233790, } -def test_update_mock_min_str(mock_opt: MockOptimizer, - mock_configs_str: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_min_str( + mock_opt: MockOptimizer, + mock_configs_str: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the mock optimizer with all-strings data. """ _test_opt_update_min(mock_opt, mock_configs_str, mock_scores, mock_status) -def test_update_mock_max(mock_opt_max: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_max( + mock_opt_max: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the mock optimizer on the maximization problem. 
""" _test_opt_update_max(mock_opt_max, mock_configs, mock_scores, mock_status) -def test_update_flaml(flaml_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_flaml( + flaml_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_min(flaml_opt, mock_configs, mock_scores, mock_status) -def test_update_flaml_max(flaml_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_flaml_max( + flaml_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_max(flaml_opt_max, mock_configs, mock_scores, mock_status) -def test_update_smac(smac_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_smac( + smac_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the SMAC optimizer. """ _test_opt_update_min(smac_opt, mock_configs, mock_scores, mock_status) -def test_update_smac_max(smac_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_smac_max( + smac_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """ Test the bulk update of the SMAC optimizer. """ diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index 2a50f95e8c..c30d1c32d2 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -56,7 +56,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: (status, _ts, output) = env_context.run() assert status.is_succeeded() assert output is not None - score = output['score'] + score = output["score"] assert isinstance(score, float) assert 60 <= score <= 120 logger("score: %s", str(score)) @@ -69,8 +69,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: return (best_score["score"], best_tunables) -def test_mock_optimization_loop(mock_env_no_noise: MockEnv, - mock_opt: MockOptimizer) -> None: +def test_mock_optimization_loop(mock_env_no_noise: MockEnv, mock_opt: MockOptimizer) -> None: """ Toy optimization loop with mock environment and optimizer. """ @@ -84,8 +83,9 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv, } -def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, - mock_opt_no_defaults: MockOptimizer) -> None: +def test_mock_optimization_loop_no_defaults( + mock_env_no_noise: MockEnv, mock_opt_no_defaults: MockOptimizer +) -> None: """ Toy optimization loop with mock environment and optimizer. 
""" @@ -99,8 +99,7 @@ def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, } -def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, - flaml_opt: MlosCoreOptimizer) -> None: +def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, flaml_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and FLAML optimizer. """ @@ -115,8 +114,7 @@ def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, # @pytest.mark.skip(reason="SMAC is not deterministic") -def test_smac_optimization_loop(mock_env_no_noise: MockEnv, - smac_opt: MlosCoreOptimizer) -> None: +def test_smac_optimization_loop(mock_env_no_noise: MockEnv, smac_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and SMAC optimizer. """ diff --git a/mlos_bench/mlos_bench/tests/services/__init__.py b/mlos_bench/mlos_bench/tests/services/__init__.py index 1971c01799..bf4df0e6c2 100644 --- a/mlos_bench/mlos_bench/tests/services/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/__init__.py @@ -11,8 +11,8 @@ from .remote import MockFileShareService, MockRemoteExecService, MockVMService __all__ = [ - 'MockLocalExecService', - 'MockFileShareService', - 'MockRemoteExecService', - 'MockVMService', + "MockLocalExecService", + "MockFileShareService", + "MockRemoteExecService", + "MockVMService", ] diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index d6cb869f09..881b6b6cfa 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -29,15 +29,19 @@ def config_persistence_service() -> ConfigPersistenceService: """ Test fixture for ConfigPersistenceService. """ - return ConfigPersistenceService({ - "config_path": [ - "./non-existent-dir/test/foo/bar", # Non-existent config path - ".", # cwd - str(files("mlos_bench.tests.config").joinpath("")), # Test configs (relative to mlos_bench/tests) - # Shouldn't be necessary since we automatically add this. - # str(files("mlos_bench.config").joinpath("")), # Stock configs - ] - }) + return ConfigPersistenceService( + { + "config_path": [ + "./non-existent-dir/test/foo/bar", # Non-existent config path + ".", # cwd + str( + files("mlos_bench.tests.config").joinpath("") + ), # Test configs (relative to mlos_bench/tests) + # Shouldn't be necessary since we automatically add this. + # str(files("mlos_bench.config").joinpath("")), # Stock configs + ] + } + ) def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None: @@ -78,7 +82,7 @@ def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService assert os.path.exists(path) assert os.path.samefile( ConfigPersistenceService.BUILTIN_CONFIG_PATH, - os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]) + os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]), ) @@ -106,8 +110,9 @@ def test_load_config(config_persistence_service: ConfigPersistenceService) -> No """ Check if we can successfully load a config file located relative to `config_path`. 
""" - tunables_data = config_persistence_service.load_config("tunable-values/tunable-values-example.jsonc", - ConfigSchema.TUNABLE_VALUES) + tunables_data = config_persistence_service.load_config( + "tunable-values/tunable-values-example.jsonc", ConfigSchema.TUNABLE_VALUES + ) assert tunables_data is not None assert isinstance(tunables_data, dict) assert len(tunables_data) >= 1 diff --git a/mlos_bench/mlos_bench/tests/services/local/__init__.py b/mlos_bench/mlos_bench/tests/services/local/__init__.py index c6dbf7c021..a09fd442fb 100644 --- a/mlos_bench/mlos_bench/tests/services/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/__init__.py @@ -10,5 +10,5 @@ from .mock import MockLocalExecService __all__ = [ - 'MockLocalExecService', + "MockLocalExecService", ] diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index 572195dcc5..78cebdf517 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -56,11 +56,12 @@ def test_run_python_script(local_exec_service: LocalExecService) -> None: json.dump(params_meta, fh_meta) script_path = local_exec_service.config_loader_service.resolve_path( - "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py") + "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py" + ) - (return_code, _stdout, stderr) = local_exec_service.local_exec([ - f"{script_path} {input_file} {meta_file} {output_file}" - ], cwd=temp_dir, env=params) + (return_code, _stdout, stderr) = local_exec_service.local_exec( + [f"{script_path} {input_file} {meta_file} {output_file}"], cwd=temp_dir, env=params + ) assert stderr.strip() == "" assert return_code == 0 diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index bd5b3b7d7f..c9dbecd93c 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -24,25 +24,27 @@ def test_split_cmdline() -> None: """ Test splitting a commandline into subcommands. """ - cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" + cmdline = ( + ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" + ) assert list(split_cmdline(cmdline)) == [ - ['.', 'env.sh'], - ['&&'], - ['('], - ['echo', 'hello'], - ['&&'], - ['echo', 'world'], - ['|'], - ['tee'], - ['>'], - ['/tmp/test'], - ['||'], - ['echo', 'foo'], - ['&&'], - ['echo', '$var'], - [';'], - ['true'], - [')'], + [".", "env.sh"], + ["&&"], + ["("], + ["echo", "hello"], + ["&&"], + ["echo", "world"], + ["|"], + ["tee"], + [">"], + ["/tmp/test"], + ["||"], + ["echo", "foo"], + ["&&"], + ["echo", "$var"], + [";"], + ["true"], + [")"], ] @@ -67,7 +69,10 @@ def test_resolve_script(local_exec_service: LocalExecService) -> None: expected_cmdline = f". 
env.sh && {script_abspath} --input foo" subcmds_tokens = split_cmdline(orig_cmdline) # pylint: disable=protected-access - subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens] + subcmds_tokens = [ + local_exec_service._resolve_cmdline_script_path(subcmd_tokens) + for subcmd_tokens in subcmds_tokens + ] cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens] expanded_cmdline = " ".join(cmdline_tokens) assert expanded_cmdline == expected_cmdline @@ -89,10 +94,7 @@ def test_run_script_multiline(local_exec_service: LocalExecService) -> None: Run a multiline script locally and check the results. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec([ - "echo hello", - "echo world" - ]) + (return_code, stdout, stderr) = local_exec_service.local_exec(["echo hello", "echo world"]) assert return_code == 0 assert stdout.strip().split() == ["hello", "world"] assert stderr.strip() == "" @@ -103,12 +105,12 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: Run a multiline script locally and pass the environment variables to it. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec([ - r"echo $var", # Unix shell - r"echo %var%" # Windows cmd - ], env={"var": "VALUE", "int_var": 10}) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [r"echo $var", r"echo %var%"], # Unix shell # Windows cmd + env={"var": "VALUE", "int_var": 10}, + ) assert return_code == 0 - if sys.platform == 'win32': + if sys.platform == "win32": assert stdout.strip().split() == ["$var", "VALUE"] else: assert stdout.strip().split() == ["VALUE", "%var%"] @@ -121,23 +123,26 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: """ with local_exec_service.temp_dir_context() as temp_dir: - (return_code, stdout, stderr) = local_exec_service.local_exec([ - "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows - "echo '111,222' >> output.csv", - "echo '333,444' >> output.csv", - ], cwd=temp_dir) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [ + "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows + "echo '111,222' >> output.csv", + "echo '333,444' >> output.csv", + ], + cwd=temp_dir, + ) assert return_code == 0 assert stdout.strip() == "" assert stderr.strip() == "" data = pandas.read_csv(path_join(temp_dir, "output.csv")) - if sys.platform == 'win32': + if sys.platform == "win32": # Workaround for Python's subprocess module on Windows adding a # space inbetween the col1,col2 arg and the redirect symbol which # cmd poorly interprets as being part of the original string arg. # Without this, we get "col2 " as the second column name. 
- data.rename(str.rstrip, axis='columns', inplace=True) + data.rename(str.rstrip, axis="columns", inplace=True) assert all(data.col1 == [111, 333]) assert all(data.col2 == [222, 444]) @@ -152,10 +157,13 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None with open(path_join(temp_dir, input_file), "wt", encoding="utf-8") as fh_input: fh_input.write("hello\n") - (return_code, stdout, stderr) = local_exec_service.local_exec([ - f"echo 'world' >> {input_file}", - f"echo 'test' >> {input_file}", - ], cwd=temp_dir) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [ + f"echo 'world' >> {input_file}", + f"echo 'test' >> {input_file}", + ], + cwd=temp_dir, + ) assert return_code == 0 assert stdout.strip() == "" @@ -178,11 +186,13 @@ def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> N """ Try to run a series of commands, one of which fails, and abort early. """ - (return_code, stdout, _stderr) = local_exec_service.local_exec([ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", - "echo world", - ]) + (return_code, stdout, _stderr) = local_exec_service.local_exec( + [ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == "win32" else "false", + "echo world", + ] + ) assert return_code != 0 assert stdout.strip() == "hello" @@ -192,11 +202,13 @@ def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> No Try to run a series of commands, one of which fails, but let it pass. """ local_exec_service.abort_on_error = False - (return_code, stdout, _stderr) = local_exec_service.local_exec([ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", - "echo world", - ]) + (return_code, stdout, _stderr) = local_exec_service.local_exec( + [ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == "win32" else "false", + "echo world", + ] + ) assert return_code == 0 assert stdout.splitlines() == [ "hello", @@ -214,13 +226,17 @@ def test_temp_dir_path_expansion() -> None: # the fact. with tempfile.TemporaryDirectory() as temp_dir: global_config = { - "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" + "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" } config = { # The temp_dir for the LocalExecService should get expanded via workdir global config. 
"temp_dir": "$workdir/temp", } - local_exec_service = LocalExecService(config, global_config, parent=ConfigPersistenceService()) + local_exec_service = LocalExecService( + config, global_config, parent=ConfigPersistenceService() + ) # pylint: disable=protected-access assert isinstance(local_exec_service._temp_dir, str) - assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join(temp_dir, "temp", abs_path=True) + assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join( + temp_dir, "temp", abs_path=True + ) diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py index eede9383bc..9164da60df 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py @@ -9,5 +9,5 @@ from .mock_local_exec_service import MockLocalExecService __all__ = [ - 'MockLocalExecService', + "MockLocalExecService", ] diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py index db8f0134c4..3df89aaed9 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py @@ -35,16 +35,21 @@ class MockLocalExecService(TempDirContextService, SupportsLocalExec): Mock methods for LocalExecService testing. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.local_exec]) + config, global_config, parent, self.merge_methods(methods, [self.local_exec]) ) - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: return (0, "", "") diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py index 835738015b..4ef38ab440 100644 --- a/mlos_bench/mlos_bench/tests/services/mock_service.py +++ b/mlos_bench/mlos_bench/tests/services/mock_service.py @@ -28,19 +28,24 @@ class MockServiceBase(Service, SupportsSomeMethod): """A base service class for testing.""" def __init__( - self, - config: Optional[dict] = None, - global_config: Optional[dict] = None, - parent: Optional[Service] = None, - methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None) -> None: + self, + config: Optional[dict] = None, + global_config: Optional[dict] = None, + parent: Optional[Service] = None, + methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None, + ) -> None: super().__init__( config, global_config, parent, - self.merge_methods(methods, [ - self.some_method, - self.some_other_method, - ])) + self.merge_methods( + methods, + [ + self.some_method, + self.some_other_method, + ], + ), + ) def some_method(self) -> str: """some_method""" diff --git 
a/mlos_bench/mlos_bench/tests/services/remote/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/__init__.py index e8a87ab684..df3fb69c53 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/__init__.py @@ -12,7 +12,7 @@ from .mock.mock_vm_service import MockVMService __all__ = [ - 'MockFileShareService', - 'MockRemoteExecService', - 'MockVMService', + "MockFileShareService", + "MockRemoteExecService", + "MockVMService", ] diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index c6475e6936..d451370b63 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -18,7 +18,9 @@ @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_download_file( + mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService +) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" @@ -26,8 +28,9 @@ def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fil local_path = f"{local_folder}/{filename}" mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \ - patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client: + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client: mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False)) azure_fileshare.download(config, remote_path, local_path) @@ -47,38 +50,41 @@ def make_dir_client_returns(remote_folder: str) -> dict: return { remote_folder: Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock(return_value=[ - {"name": "a_folder", "is_directory": True}, - {"name": "a_file_1.csv", "is_directory": False}, - ]) + list_directories_and_files=Mock( + return_value=[ + {"name": "a_folder", "is_directory": True}, + {"name": "a_file_1.csv", "is_directory": False}, + ] + ), ), f"{remote_folder}/a_folder": Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock(return_value=[ - {"name": "a_file_2.csv", "is_directory": False}, - ]) - ), - f"{remote_folder}/a_file_1.csv": Mock( - exists=Mock(return_value=False) - ), - f"{remote_folder}/a_folder/a_file_2.csv": Mock( - exists=Mock(return_value=False) + list_directories_and_files=Mock( + return_value=[ + {"name": "a_file_2.csv", "is_directory": False}, + ] + ), ), + f"{remote_folder}/a_file_1.csv": Mock(exists=Mock(return_value=False)), + f"{remote_folder}/a_folder/a_file_2.csv": Mock(exists=Mock(return_value=False)), } @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_non_recursive(mock_makedirs: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_download_folder_non_recursive( + mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService +) 
-> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ - patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + with patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] @@ -87,47 +93,63 @@ def test_download_folder_non_recursive(mock_makedirs: MagicMock, mock_get_file_client.assert_called_with( f"{remote_folder}/a_file_1.csv", ) - mock_get_directory_client.assert_has_calls([ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - ], any_order=True) + mock_get_directory_client.assert_has_calls( + [ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + ], + any_order=True, + ) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_download_folder_recursive( + mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ - patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + with patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] azure_fileshare.download(config, remote_folder, local_folder, recursive=True) - mock_get_file_client.assert_has_calls([ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) - mock_get_directory_client.assert_has_calls([ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) + mock_get_file_client.assert_has_calls( + [ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) + mock_get_directory_client.assert_has_calls( + [ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") -def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_upload_file( + mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService +) -> 
None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access mock_isdir.return_value = False config: dict = {} @@ -143,6 +165,7 @@ def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshar class MyDirEntry: # pylint: disable=too-few-public-methods """Dummy class for os.DirEntry""" + def __init__(self, name: str, is_a_dir: bool): self.name = name self.is_a_dir = is_a_dir @@ -186,17 +209,19 @@ def process_paths(input_path: str) -> str: @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_non_recursive(mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_upload_directory_non_recursive( + mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: @@ -208,23 +233,28 @@ def test_upload_directory_non_recursive(mock_scandir: MagicMock, @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_recursive(mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_upload_directory_recursive( + mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: azure_fileshare.upload(config, local_folder, remote_folder, recursive=True) - mock_get_file_client.assert_has_calls([ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) + mock_get_file_client.assert_has_calls( + [ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) diff --git 
a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py index d6d55d3975..af239a158e 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py @@ -18,16 +18,20 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_network_deployment_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_network_service: AzureNetworkService) -> None: +def test_wait_network_deployment_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_network_service: AzureNetworkService, +) -> None: """ Test retries of the network deployment operation. """ @@ -35,8 +39,12 @@ def test_wait_network_deployment_retry(mock_getconn: MagicMock, # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] @@ -49,30 +57,37 @@ def test_wait_network_deployment_retry(mock_getconn: MagicMock, "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True) + is_setup=True, + ) assert status == operation_status @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("deprovision_network", True), - ]) + ], +) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), # These should succeed since we set ignore_errors=True by default (401, Status.SUCCEEDED), (404, Status.SUCCEEDED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_network_operation_status(mock_requests: MagicMock, - azure_network_service: AzureNetworkService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status) -> None: +def test_network_operation_status( + mock_requests: MagicMock, + azure_network_service: AzureNetworkService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status, +) -> None: """ Test network operation status. 
""" @@ -89,22 +104,30 @@ def test_network_operation_status(mock_requests: MagicMock, @pytest.fixture -def test_azure_network_service_no_deployment_template(azure_auth_service: AzureAuthService) -> None: +def test_azure_network_service_no_deployment_template( + azure_auth_service: AzureAuthService, +) -> None: """ Tests creating a network services without a deployment template (should fail). """ with pytest.raises(ValueError): - _ = AzureNetworkService(config={ - "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", + _ = AzureNetworkService( + config={ + "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", + }, }, - }, parent=azure_auth_service) + parent=azure_auth_service, + ) with pytest.raises(ValueError): - _ = AzureNetworkService(config={ - # "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", + _ = AzureNetworkService( + config={ + # "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", + }, }, - }, parent=azure_auth_service) + parent=azure_auth_service, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 1d84d73cab..fc72131c0c 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -19,16 +19,20 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_host_deployment_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService) -> None: +def test_wait_host_deployment_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService, +) -> None: """ Test retries of the host deployment operation. 
""" @@ -36,8 +40,12 @@ def test_wait_host_deployment_retry(mock_getconn: MagicMock, # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] @@ -50,7 +58,8 @@ def test_wait_host_deployment_retry(mock_getconn: MagicMock, "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True) + is_setup=True, + ) assert status == operation_status @@ -75,8 +84,14 @@ def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAut } azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) assert azure_vm_service.deploy_params["location"] == global_config["location"] - assert azure_vm_service.deploy_params["vmMeta"] == f'{global_config["vmName"]}-{global_config["location"]}' - assert azure_vm_service.deploy_params["vmNsg"] == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' + assert ( + azure_vm_service.deploy_params["vmMeta"] + == f'{global_config["vmName"]}-{global_config["location"]}' + ) + assert ( + azure_vm_service.deploy_params["vmNsg"] + == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' + ) def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None: @@ -98,14 +113,15 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N } with pytest.raises(ValueError): config_with_custom_data = deepcopy(config) - config_with_custom_data['deploymentTemplateParameters']['customData'] = "DUMMY_CUSTOM_DATA" # type: ignore[index] + config_with_custom_data["deploymentTemplateParameters"]["customData"] = "DUMMY_CUSTOM_DATA" # type: ignore[index] AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service) azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) - assert azure_vm_service.deploy_params['customData'] + assert azure_vm_service.deploy_params["customData"] @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("start_host", True), ("stop_host", True), ("shutdown", True), @@ -113,22 +129,27 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N ("deallocate_host", True), ("restart_host", True), ("reboot", True), - ]) + ], +) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_vm_operation_status(mock_requests: MagicMock, - azure_vm_service: AzureVMService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status) -> None: +def test_vm_operation_status( + 
mock_requests: MagicMock, + azure_vm_service: AzureVMService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status, +) -> None: """ Test VM operation status. """ @@ -145,12 +166,14 @@ def test_vm_operation_status(mock_requests: MagicMock, @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("provision_host", True), - ]) -def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, - operation_name: str, - accepts_params: bool) -> None: + ], +) +def test_vm_operation_invalid( + azure_vm_service_remote_exec_only: AzureVMService, operation_name: str, accepts_params: bool +) -> None: """ Test VM operation status for an incomplete service config. """ @@ -161,8 +184,9 @@ def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, @patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep") @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_ready( + mock_session: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService +) -> None: """ Test waiting for the completion of the remote VM operation. """ @@ -183,23 +207,20 @@ def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, status, _ = azure_vm_service.wait_host_operation(params) - assert (async_url, ) == mock_session.return_value.get.call_args[0] - assert (retry_after, ) == mock_sleep.call_args[0] + assert (async_url,) == mock_session.return_value.get.call_args[0] + assert (retry_after,) == mock_sleep.call_args[0] assert status.is_succeeded() @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_timeout(mock_session: MagicMock, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_timeout( + mock_session: MagicMock, azure_vm_service: AzureVMService +) -> None: """ Test the time out of the remote VM operation. """ # Mock response header - params = { - "asyncResultsUrl": "DUMMY_ASYNC_URL", - "vmName": "test-vm", - "pollInterval": 1 - } + params = {"asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", "pollInterval": 1} mock_status_response = MagicMock(status_code=200) mock_status_response.json.return_value = { @@ -212,16 +233,20 @@ def test_wait_vm_operation_timeout(mock_session: MagicMock, @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_vm_operation_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService, +) -> None: """ Test the retries of the remote VM operation. 
""" @@ -229,8 +254,12 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock, # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"status": "InProgress"}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), make_httplib_json_response(200, {"status": "InProgress"}), make_httplib_json_response(200, {"status": "Succeeded"}), ] @@ -241,20 +270,27 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock, "requestTotalRetries": total_retries, "asyncResultsUrl": "https://DUMMY_ASYNC_URL", "vmName": "test-vm", - }) + } + ) assert status == operation_status @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService, - http_status_code: int, operation_status: Status) -> None: +def test_remote_exec_status( + mock_requests: MagicMock, + azure_vm_service_remote_exec_only: AzureVMService, + http_status_code: int, + operation_status: Status, +) -> None: """ Test waiting for completion of the remote execution on Azure. """ @@ -262,19 +298,24 @@ def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_ex mock_response = MagicMock() mock_response.status_code = http_status_code - mock_response.json = MagicMock(return_value={ - "fake response": "body as json to dict", - }) + mock_response.json = MagicMock( + return_value={ + "fake response": "body as json to dict", + } + ) mock_requests.post.return_value = mock_response - status, _ = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={}) + status, _ = azure_vm_service_remote_exec_only.remote_exec( + script, config={"vmName": "test-vm"}, env_params={} + ) assert status == operation_status @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_headers_output(mock_requests: MagicMock, - azure_vm_service_remote_exec_only: AzureVMService) -> None: +def test_remote_exec_headers_output( + mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService +) -> None: """ Check if HTTP headers from the remote execution on Azure are correct. 
""" @@ -284,18 +325,22 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, mock_response = MagicMock() mock_response.status_code = 202 - mock_response.headers = { - "Azure-AsyncOperation": async_url_value - } - mock_response.json = MagicMock(return_value={ - "fake response": "body as json to dict", - }) + mock_response.headers = {"Azure-AsyncOperation": async_url_value} + mock_response.json = MagicMock( + return_value={ + "fake response": "body as json to dict", + } + ) mock_requests.post.return_value = mock_response - _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={ - "param_1": 123, - "param_2": "abc", - }) + _, cmd_output = azure_vm_service_remote_exec_only.remote_exec( + script, + config={"vmName": "test-vm"}, + env_params={ + "param_1": 123, + "param_2": "abc", + }, + ) assert async_url_key in cmd_output assert cmd_output[async_url_key] == async_url_value @@ -303,15 +348,13 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, assert mock_requests.post.call_args[1]["json"] == { "commandId": "RunShellScript", "script": script, - "parameters": [ - {"name": "param_1", "value": 123}, - {"name": "param_2", "value": "abc"} - ] + "parameters": [{"name": "param_1", "value": 123}, {"name": "param_2", "value": "abc"}], } @pytest.mark.parametrize( - ("operation_status", "wait_output", "results_output"), [ + ("operation_status", "wait_output", "results_output"), + [ ( Status.SUCCEEDED, { @@ -323,13 +366,18 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, } } }, - {"stdout": "DUMMY_STDOUT_STDERR"} + {"stdout": "DUMMY_STDOUT_STDERR"}, ), (Status.PENDING, {}, {}), (Status.FAILED, {}, {}), - ]) -def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status, - wait_output: dict, results_output: dict) -> None: + ], +) +def test_get_remote_exec_results( + azure_vm_service_remote_exec_only: AzureVMService, + operation_status: Status, + wait_output: dict, + results_output: dict, +) -> None: """ Test getting the results of the remote execution on Azure. """ diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 2794bb01cf..6a2d62267b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -30,8 +30,9 @@ def config_persistence_service() -> ConfigPersistenceService: @pytest.fixture -def azure_auth_service(config_persistence_service: ConfigPersistenceService, - monkeypatch: pytest.MonkeyPatch) -> AzureAuthService: +def azure_auth_service( + config_persistence_service: ConfigPersistenceService, monkeypatch: pytest.MonkeyPatch +) -> AzureAuthService: """ Creates a dummy AzureAuthService for tests that require it. """ @@ -45,19 +46,23 @@ def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkS """ Creates a dummy Azure VM service for tests that require it. 
""" - return AzureNetworkService(config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", + return AzureNetworkService( + config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", + }, + "pollInterval": 1, + "pollTimeout": 2, }, - "pollInterval": 1, - "pollTimeout": 2 - }, global_config={ - "deploymentName": "TEST_DEPLOYMENT-VNET", - "vnetName": "test-vnet", # Should come from the upper-level config - }, parent=azure_auth_service) + global_config={ + "deploymentName": "TEST_DEPLOYMENT-VNET", + "vnetName": "test-vnet", # Should come from the upper-level config + }, + parent=azure_auth_service, + ) @pytest.fixture @@ -65,19 +70,23 @@ def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService: """ Creates a dummy Azure VM service for tests that require it. """ - return AzureVMService(config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", + return AzureVMService( + config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", + }, + "pollInterval": 1, + "pollTimeout": 2, + }, + global_config={ + "deploymentName": "TEST_DEPLOYMENT-VM", + "vmName": "test-vm", # Should come from the upper-level config }, - "pollInterval": 1, - "pollTimeout": 2 - }, global_config={ - "deploymentName": "TEST_DEPLOYMENT-VM", - "vmName": "test-vm", # Should come from the upper-level config - }, parent=azure_auth_service) + parent=azure_auth_service, + ) @pytest.fixture @@ -85,14 +94,18 @@ def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> A """ Creates a dummy Azure VM service with no deployment template. """ - return AzureVMService(config={ - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "pollInterval": 1, - "pollTimeout": 2, - }, global_config={ - "vmName": "test-vm", # Should come from the upper-level config - }, parent=azure_auth_service) + return AzureVMService( + config={ + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "pollInterval": 1, + "pollTimeout": 2, + }, + global_config={ + "vmName": "test-vm", # Should come from the upper-level config + }, + parent=azure_auth_service, + ) @pytest.fixture @@ -101,8 +114,12 @@ def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> Azu Creates a dummy AzureFileShareService for tests that require it. 
""" with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"): - return AzureFileShareService(config={ - "storageAccountName": "TEST_ACCOUNT_NAME", - "storageFileShareName": "TEST_FS_NAME", - "storageAccountKey": "TEST_ACCOUNT_KEY" - }, global_config={}, parent=config_persistence_service) + return AzureFileShareService( + config={ + "storageAccountName": "TEST_ACCOUNT_NAME", + "storageFileShareName": "TEST_FS_NAME", + "storageAccountKey": "TEST_ACCOUNT_KEY", + }, + global_config={}, + parent=config_persistence_service, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index b9474f0709..fb1c4ee39b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -20,16 +20,24 @@ class MockAuthService(Service, SupportsAuth): A collection Service functions for mocking authentication ops. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.get_access_token, - self.get_auth_headers, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.get_access_token, + self.get_auth_headers, + ], + ), ) def get_access_token(self) -> str: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index 1a026966a8..79f8c608c2 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -21,21 +21,30 @@ class MockFileShareService(FileShareService, SupportsFileShareOps): A collection Service functions for mocking file share ops. 
""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) self._upload: List[Tuple[str, str]] = [] self._download: List[Tuple[str, str]] = [] - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: self._upload.append((local_path, remote_path)) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: self._download.append((remote_path, local_path)) def get_upload(self) -> List[Tuple[str, str]]: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py index e6169d9f93..6bf9fc8d05 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py @@ -20,10 +20,13 @@ class MockNetworkService(Service, SupportsNetworkProvisioning): Mock Network service for testing. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock network services proxy. @@ -38,13 +41,19 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - name: mock_operation for name in ( - # SupportsNetworkProvisioning: - "provision_network", - "deprovision_network", - "wait_network_deployment", - ) - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + name: mock_operation + for name in ( + # SupportsNetworkProvisioning: + "provision_network", + "deprovision_network", + "wait_network_deployment", + ) + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py index ee99251c64..38d759f53c 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py @@ -18,10 +18,13 @@ class MockRemoteExecService(Service, SupportsRemoteExec): Mock remote script execution service. 
""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock remote exec service. @@ -36,9 +39,14 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - "remote_exec": mock_operation, - "get_remote_exec_results": mock_operation, - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + "remote_exec": mock_operation, + "get_remote_exec_results": mock_operation, + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index a44edaf080..3ae13cf6a6 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -20,10 +20,13 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, Supports Mock VM service for testing. """ - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock VM services proxy. @@ -38,23 +41,29 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - name: mock_operation for name in ( - # SupportsHostProvisioning: - "wait_host_deployment", - "provision_host", - "deprovision_host", - "deallocate_host", - # SupportsHostOps: - "start_host", - "stop_host", - "restart_host", - "wait_host_operation", - # SupportsOsOps: - "shutdown", - "reboot", - "wait_os_operation", - ) - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + name: mock_operation + for name in ( + # SupportsHostProvisioning: + "wait_host_deployment", + "provision_host", + "deprovision_host", + "deallocate_host", + # SupportsHostOps: + "start_host", + "stop_host", + "restart_host", + "wait_host_operation", + # SupportsOsOps: + "shutdown", + "reboot", + "wait_os_operation", + ) + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index e0060d8047..16c88dc791 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -17,9 +17,9 @@ # The SSH test server port and name. 
# See Also: docker-compose.yml SSH_TEST_SERVER_PORT = 2254 -SSH_TEST_SERVER_NAME = 'ssh-server' -ALT_TEST_SERVER_NAME = 'alt-server' -REBOOT_TEST_SERVER_NAME = 'reboot-server' +SSH_TEST_SERVER_NAME = "ssh-server" +ALT_TEST_SERVER_NAME = "alt-server" +REBOOT_TEST_SERVER_NAME = "reboot-server" @dataclass @@ -42,8 +42,12 @@ def get_port(self, uncached: bool = False) -> int: Note: this value can change when the service restarts so we can't rely on the DockerServices. """ if self._port is None or uncached: - port_cmd = run(f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", - shell=True, check=True, capture_output=True) + port_cmd = run( + f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", + shell=True, + check=True, + capture_output=True, + ) self._port = int(port_cmd.stdout.decode().strip().split(":")[1]) return self._port diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 6f05fe953b..1d9f570fdf 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -30,26 +30,28 @@ # pylint: disable=redefined-outer-name -HOST_DOCKER_NAME = 'host.docker.internal' +HOST_DOCKER_NAME = "host.docker.internal" @pytest.fixture(scope="session") def ssh_test_server_hostname() -> str: """Returns the local hostname to use to connect to the test ssh server.""" - if sys.platform != 'win32' and resolve_host_name(HOST_DOCKER_NAME): + if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): # On Linux, if we're running in a docker container, we can use the # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. return HOST_DOCKER_NAME # Docker (Desktop) for Windows (WSL2) uses a special networking magic # to refer to the host machine as `localhost` when exposing ports. # In all other cases, assume we're executing directly inside conda on the host. - return 'localhost' + return "localhost" @pytest.fixture(scope="session") -def ssh_test_server(ssh_test_server_hostname: str, - docker_compose_project_name: str, - locked_docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]: +def ssh_test_server( + ssh_test_server_hostname: str, + docker_compose_project_name: str, + locked_docker_services: DockerServices, +) -> Generator[SshTestServerInfo, None, None]: """ Fixture for getting the ssh test server services setup via docker-compose using pytest-docker. 
@@ -66,23 +68,35 @@ def ssh_test_server(ssh_test_server_hostname: str, compose_project_name=docker_compose_project_name, service_name=SSH_TEST_SERVER_NAME, hostname=ssh_test_server_hostname, - username='root', - id_rsa_path=id_rsa_file.name) - wait_docker_service_socket(locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port()) + username="root", + id_rsa_path=id_rsa_file.name, + ) + wait_docker_service_socket( + locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port() + ) id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa" docker_cp_cmd = f"docker compose -p {docker_compose_project_name} cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}" - cmd = run(docker_cp_cmd.split(), check=True, cwd=os.path.dirname(__file__), capture_output=True, text=True) + cmd = run( + docker_cp_cmd.split(), + check=True, + cwd=os.path.dirname(__file__), + capture_output=True, + text=True, + ) if cmd.returncode != 0: - raise RuntimeError(f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " - + f"[return={cmd.returncode}]: {str(cmd.stderr)}") + raise RuntimeError( + f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " + + f"[return={cmd.returncode}]: {str(cmd.stderr)}" + ) os.chmod(id_rsa_file.name, 0o600) yield ssh_test_server_info # NamedTempFile deleted on context exit @pytest.fixture(scope="session") -def alt_test_server(ssh_test_server: SshTestServerInfo, - locked_docker_services: DockerServices) -> SshTestServerInfo: +def alt_test_server( + ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices +) -> SshTestServerInfo: """ Fixture for getting the second ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. @@ -95,14 +109,18 @@ def alt_test_server(ssh_test_server: SshTestServerInfo, service_name=ALT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path) - wait_docker_service_socket(locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port()) + id_rsa_path=ssh_test_server.id_rsa_path, + ) + wait_docker_service_socket( + locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port() + ) return alt_test_server_info @pytest.fixture(scope="session") -def reboot_test_server(ssh_test_server: SshTestServerInfo, - locked_docker_services: DockerServices) -> SshTestServerInfo: +def reboot_test_server( + ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices +) -> SshTestServerInfo: """ Fixture for getting the third ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. 
@@ -115,8 +133,13 @@ def reboot_test_server(ssh_test_server: SshTestServerInfo, service_name=REBOOT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path) - wait_docker_service_socket(locked_docker_services, reboot_test_server_info.hostname, reboot_test_server_info.get_port()) + id_rsa_path=ssh_test_server.id_rsa_path, + ) + wait_docker_service_socket( + locked_docker_services, + reboot_test_server_info.hostname, + reboot_test_server_info.get_port(), + ) return reboot_test_server_info diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index f2bbbe4b8a..c77c57def8 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -52,8 +52,9 @@ def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, @requires_docker -def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_single_file( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService single file download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -66,7 +67,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, lines = [line + "\n" for line in lines] # 1. Write a local file and upload it. - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines(lines) temp_file.flush() temp_file.close() @@ -78,7 +79,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, ) # 2. Download the remote file and compare the contents. - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.close() ssh_fileshare_service.download( params=config, @@ -86,14 +87,15 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, local_path=temp_file.name, ) # Download will replace the inode at that name, so we need to reopen the file. - with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: + with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == lines @requires_docker -def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_recursive( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService recursive download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -113,14 +115,16 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, "bar", ], } - files_lines = {path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()} + files_lines = { + path: [line + "\n" for line in lines] for (path, lines) in files_lines.items() + } with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2: # Setup the directory structure. 
- for (file_path, lines) in files_lines.items(): + for file_path, lines in files_lines.items(): path = Path(tempdir1, file_path) path.parent.mkdir(parents=True, exist_ok=True) - with open(path, mode='w+t', encoding='utf-8') as temp_file: + with open(path, mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines(lines) temp_file.flush() assert os.path.getsize(path) > 0 @@ -147,15 +151,16 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, @requires_docker -def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_download_file_dne( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService single file download that doesn't exist.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() canary_str = "canary" - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines([canary_str]) temp_file.flush() temp_file.close() @@ -166,20 +171,22 @@ def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, remote_path="/tmp/file-dne.txt", local_path=temp_file.name, ) - with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: + with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == [canary_str] @requires_docker -def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_upload_file_dne( + ssh_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + ssh_fileshare_service: SshFileShareService, +) -> None: """Test the SshFileShareService single file upload that doesn't exist.""" with ssh_host_service, ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() - path = '/tmp/upload-file-src-dne.txt' + path = "/tmp/upload-file-src-dne.txt" with pytest.raises(OSError): ssh_fileshare_service.upload( params=config, diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 4c8e5e0c66..40a9d4ae74 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -27,9 +27,11 @@ @requires_docker -def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, - alt_test_server: SshTestServerInfo, - ssh_host_service: SshHostService) -> None: +def test_ssh_service_remote_exec( + ssh_test_server: SshTestServerInfo, + alt_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, +) -> None: """ Test the SshHostService remote_exec. 
@@ -42,7 +44,9 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, connection_id = SshClient.id_from_params(ssh_test_server.to_connect_params()) assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None - connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(connection_id) + connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get( + connection_id + ) assert connection_client is None (status, results_info) = ssh_host_service.remote_exec( @@ -57,7 +61,9 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, assert results["stdout"].strip() == SSH_TEST_SERVER_NAME # Check that the client caching is behaving as expected. - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ + connection_id + ] assert connection is not None assert connection._username == ssh_test_server.username assert connection._host == ssh_test_server.hostname @@ -91,13 +97,15 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) - assert status.is_failed() # should retain exit code from "false" + assert status.is_failed() # should retain exit code from "false" stdout = str(results["stdout"]) assert stdout.splitlines() == [ "BAR=bar", "UNUSED=", ] - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ + connection_id + ] assert connection._local_port == local_port # Close the connection (gracefully) @@ -114,7 +122,7 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, config=config, # Also test interacting with environment_variables. env_params={ - 'FOO': 'foo', + "FOO": "foo", }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) @@ -127,17 +135,21 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, "BAZ=", ] # Make sure it looks like we reconnected. - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ + connection_id + ] assert connection._local_port != local_port # Make sure the cache is cleaned up on context exit. assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0 -def check_ssh_service_reboot(docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - graceful: bool) -> None: +def check_ssh_service_reboot( + docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + graceful: bool, +) -> None: """ Check the SshHostService reboot operation. """ @@ -148,11 +160,7 @@ def check_ssh_service_reboot(docker_services: DockerServices, with ssh_host_service: reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True) (status, results_info) = ssh_host_service.remote_exec( - script=[ - 'echo "sleeping..."', - 'sleep 30', - 'echo "should not reach this point"' - ], + script=['echo "sleeping..."', "sleep 30", 'echo "should not reach this point"'], config=reboot_test_srv_ssh_svc_conf, env_params={}, ) @@ -161,8 +169,9 @@ def check_ssh_service_reboot(docker_services: DockerServices, time.sleep(1) # Now try to restart the server. 
- (status, reboot_results_info) = ssh_host_service.reboot(params=reboot_test_srv_ssh_svc_conf, - force=not graceful) + (status, reboot_results_info) = ssh_host_service.reboot( + params=reboot_test_srv_ssh_svc_conf, force=not graceful + ) assert status.is_pending() (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info) @@ -183,19 +192,34 @@ def check_ssh_service_reboot(docker_services: DockerServices, time.sleep(1) # try to reconnect and see if the port changed try: - run_res = run("docker ps | grep mlos_bench-test- | grep reboot", shell=True, capture_output=True, check=False) + run_res = run( + "docker ps | grep mlos_bench-test- | grep reboot", + shell=True, + capture_output=True, + check=False, + ) print(run_res.stdout.decode()) print(run_res.stderr.decode()) - reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(uncached=True) - if reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"]: + reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config( + uncached=True + ) + if ( + reboot_test_srv_ssh_svc_conf_new["ssh_port"] + != reboot_test_srv_ssh_svc_conf["ssh_port"] + ): break except CalledProcessError as ex: _LOG.info("Failed to check port for reboot test server: %s", ex) - assert reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"] + assert ( + reboot_test_srv_ssh_svc_conf_new["ssh_port"] + != reboot_test_srv_ssh_svc_conf["ssh_port"] + ) - wait_docker_service_socket(docker_services, - reboot_test_server.hostname, - reboot_test_srv_ssh_svc_conf_new["ssh_port"]) + wait_docker_service_socket( + docker_services, + reboot_test_server.hostname, + reboot_test_srv_ssh_svc_conf_new["ssh_port"], + ) (status, results_info) = ssh_host_service.remote_exec( script=["hostname"], @@ -208,12 +232,18 @@ def check_ssh_service_reboot(docker_services: DockerServices, @requires_docker -def test_ssh_service_reboot(locked_docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService) -> None: +def test_ssh_service_reboot( + locked_docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, +) -> None: """ Test the SshHostService reboot operation. """ # Grouped together to avoid parallel runner interactions. - check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=True) - check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=False) + check_ssh_service_reboot( + locked_docker_services, reboot_test_server, ssh_host_service, graceful=True + ) + check_ssh_service_reboot( + locked_docker_services, reboot_test_server, ssh_host_service, graceful=False + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index 7bee929fea..ee9f310510 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -35,7 +35,9 @@ # We replaced pytest-lazy-fixture with pytest-lazy-fixtures: # https://github.com/TvoroG/pytest-lazy-fixture/issues/65 if version("pytest-lazy-fixture"): - raise UserWarning("pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it.") + raise UserWarning( + "pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it." 
+ ) except PackageNotFoundError: # OK: pytest-lazy-fixture not installed pass @@ -43,12 +45,14 @@ @requires_docker @requires_ssh -@pytest.mark.parametrize(["ssh_test_server_info", "server_name"], [ - (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), - (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), -]) -def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, - server_name: str) -> None: +@pytest.mark.parametrize( + ["ssh_test_server_info", "server_name"], + [ + (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), + (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), + ], +) +def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, server_name: str) -> None: """Check for the pytest-docker ssh test infra.""" assert ssh_test_server_info.service_name == server_name @@ -57,17 +61,18 @@ def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, local_port = ssh_test_server_info.get_port() assert check_socket(ip_addr, local_port) - ssh_cmd = "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " \ - + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " \ + ssh_cmd = ( + "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " + + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " + f"-p {local_port} {ssh_test_server_info.hostname} hostname" - cmd = run(ssh_cmd.split(), - capture_output=True, - text=True, - check=True) + ) + cmd = run(ssh_cmd.split(), capture_output=True, text=True, check=True) assert cmd.stdout.strip() == server_name -@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") +@pytest.mark.filterwarnings( + "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" +) def test_ssh_service_context_handler() -> None: """ Test the SSH service context manager handling. @@ -100,17 +105,23 @@ def test_ssh_service_context_handler() -> None: with ssh_fileshare_service: assert ssh_fileshare_service._in_context assert ssh_host_service._in_context - assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread \ - is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread \ + assert ( + SshService._EVENT_LOOP_CONTEXT._event_loop_thread + is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread is ssh_fileshare_service._EVENT_LOOP_CONTEXT._event_loop_thread - assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ - is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ + ) + assert ( + SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is ssh_fileshare_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + ) assert not ssh_fileshare_service._in_context # And that instance should be unusable after we are outside the context. - with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result='foo')) + with pytest.raises( + AssertionError + ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result="foo")) raise ValueError(f"Future should not have been available to wait on {future.result()}") # The background thread should remain running since we have another context still open. 
@@ -118,6 +129,6 @@ def test_ssh_service_context_handler() -> None: assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None -if __name__ == '__main__': +if __name__ == "__main__": # For debugging in Windows which has issues with pytest detection in vscode. pytest.main(["-n1", "--dist=no", "-k", "test_ssh_service_background_thread"]) diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index 2c16df65c4..20320042ee 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -19,7 +19,9 @@ mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage exp_storage_with_trials = sql_storage_fixtures.exp_storage_with_trials exp_no_tunables_storage_with_trials = sql_storage_fixtures.exp_no_tunables_storage_with_trials -mixed_numerics_exp_storage_with_trials = sql_storage_fixtures.mixed_numerics_exp_storage_with_trials +mixed_numerics_exp_storage_with_trials = ( + sql_storage_fixtures.mixed_numerics_exp_storage_with_trials +) exp_data = sql_storage_fixtures.exp_data exp_no_tunables_data = sql_storage_fixtures.exp_no_tunables_data mixed_numerics_exp_data = sql_storage_fixtures.mixed_numerics_exp_data diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index 8159043be1..685e92f7f9 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -22,23 +22,32 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) assert exp.objectives == exp_storage.opt_targets -def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None: +def test_exp_data_root_env_config( + exp_storage: Storage.Experiment, exp_data: ExperimentData +) -> None: """Tests the root_env_config property of ExperimentData""" # pylint: disable=protected-access - assert exp_data.root_env_config == (exp_storage._root_env_config, exp_storage._git_repo, exp_storage._git_commit) + assert exp_data.root_env_config == ( + exp_storage._root_env_config, + exp_storage._git_repo, + exp_storage._git_commit, + ) -def test_exp_trial_data_objectives(storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_data_objectives( + storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups +) -> None: """ Start a new trial and check the storage for the trial data. 
""" - trial_opt_new = exp_storage.new_trial(tunable_groups, config={ - "opt_target": "some-other-target", - "opt_direction": "max", - }) + trial_opt_new = exp_storage.new_trial( + tunable_groups, + config={ + "opt_target": "some-other-target", + "opt_direction": "max", + }, + ) assert trial_opt_new.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_new.trial_id, @@ -46,10 +55,13 @@ def test_exp_trial_data_objectives(storage: Storage, "opt_direction": "max", } - trial_opt_old = exp_storage.new_trial(tunable_groups, config={ - "opt_target": "back-compat", - # "opt_direction": "max", # missing - }) + trial_opt_old = exp_storage.new_trial( + tunable_groups, + config={ + "opt_target": "back-compat", + # "opt_direction": "max", # missing + }, + ) assert trial_opt_old.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_old.trial_id, @@ -74,9 +86,14 @@ def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGr assert len(results_df["tunable_config_id"].unique()) == CONFIG_COUNT assert len(results_df["trial_id"].unique()) == expected_trials_count obj_target = next(iter(exp_data.objectives)) - assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count + assert ( + len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count + ) (tunable, _covariant_group) = next(iter(tunable_groups)) - assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_trials_count + assert ( + len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) + == expected_trials_count + ) def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None: @@ -116,13 +133,15 @@ def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_config_trial_groups.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. - assert [config_trial_group.tunable_config_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [ + config_trial_group.tunable_config_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT + 1)) # And the tunable_config_trial_group_id should also match the minimum trial_id. - assert [config_trial_group.tunable_config_trial_group_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) + assert [ + config_trial_group.tunable_config_trial_group_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: @@ -130,9 +149,9 @@ def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_configs.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. 
- assert [config.tunable_config_id - for config in exp_data.tunable_configs.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [config.tunable_config_id for config in exp_data.tunable_configs.values()] == list( + range(1, CONFIG_COUNT + 1) + ) def test_exp_data_default_config_id(exp_data: ExperimentData) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index d0a5edc694..292996db4f 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -37,9 +37,9 @@ def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Start a trial and check that it is pending. """ @@ -50,14 +50,14 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_many(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_many( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Start THREE trials and check that both are pending. """ - config1 = tunable_groups.copy().assign({'idle': 'mwait'}) - config2 = tunable_groups.copy().assign({'idle': 'noidle'}) + config1 = tunable_groups.copy().assign({"idle": "mwait"}) + config2 = tunable_groups.copy().assign({"idle": "noidle"}) trial_ids = { exp_storage.new_trial(config1).trial_id, exp_storage.new_trial(config2).trial_id, @@ -72,9 +72,9 @@ def test_exp_trial_pending_many(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_fail( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Start a trial, fail it, and and check that it is NOT pending. """ @@ -85,9 +85,9 @@ def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_success(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_success( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Start a trial, finish it successfully, and and check that it is NOT pending. """ @@ -98,9 +98,9 @@ def test_exp_trial_success(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_categ(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_update_categ( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Update the trial with multiple metrics, some of which are categorical. 
""" @@ -108,21 +108,23 @@ def test_exp_trial_update_categ(exp_storage: Storage.Experiment, trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"}) assert exp_storage.load() == ( [trial.trial_id], - [{ - 'idle': 'halt', - 'kernel_sched_latency_ns': '2000000', - 'kernel_sched_migration_cost_ns': '-1', - 'vmSize': 'Standard_B4ms' - }], + [ + { + "idle": "halt", + "kernel_sched_latency_ns": "2000000", + "kernel_sched_migration_cost_ns": "-1", + "vmSize": "Standard_B4ms", + } + ], [{"score": "99.9", "benchmark": "test"}], - [Status.SUCCEEDED] + [Status.SUCCEEDED], ) @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_twice(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_update_twice( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Update the trial status twice and receive an error. """ @@ -133,9 +135,9 @@ def test_exp_trial_update_twice(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_3(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_3( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Start THREE trials, let one succeed, another one fail and keep one not updated. Check that one is still pending another one can be loaded into the optimizer. diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 7e346a5ccc..f9072a2b8d 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -36,7 +36,7 @@ def storage() -> SqlStorage: "drivername": "sqlite", "database": ":memory:", # "database": "mlos_bench.pytest.db", - } + }, ) @@ -106,7 +106,9 @@ def mixed_numerics_exp_storage( assert not exp._in_context -def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> SqlStorage.Experiment: +def _dummy_run_exp( + exp: SqlStorage.Experiment, tunable_name: Optional[str] +) -> SqlStorage.Experiment: """ Generates data by doing a simulated run of the given experiment. """ @@ -119,24 +121,30 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S (tunable_min, tunable_max) = tunable.range tunable_range = tunable_max - tunable_min rand_seed(SEED) - opt = MockOptimizer(tunables=exp.tunables, config={ - "seed": SEED, - # This should be the default, so we leave it omitted for now to test the default. - # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) - # "start_with_defaults": True, - }) + opt = MockOptimizer( + tunables=exp.tunables, + config={ + "seed": SEED, + # This should be the default, so we leave it omitted for now to test the default. 
+ # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) + # "start_with_defaults": True, + }, + ) assert opt.start_with_defaults for config_i in range(CONFIG_COUNT): tunables = opt.suggest() for repeat_j in range(CONFIG_TRIAL_REPEAT_COUNT): - trial = exp.new_trial(tunables=tunables.copy(), config={ - "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(exp.opt_targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - } - }) + trial = exp.new_trial( + tunables=tunables.copy(), + config={ + "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(exp.opt_targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + }, + }, + ) if exp.tunables: assert trial.tunable_config_id == config_i + 1 else: @@ -147,14 +155,23 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S else: tunable_value_norm = 0 timestamp = datetime.now(UTC) - trial.update_telemetry(status=Status.RUNNING, timestamp=timestamp, metrics=[ - (timestamp, "some-metric", tunable_value_norm + random() / 100), - ]) - trial.update(Status.SUCCEEDED, timestamp, metrics={ - # Give some variance on the score. - # And some influence from the tunable value. - "score": tunable_value_norm + random() / 100 - }) + trial.update_telemetry( + status=Status.RUNNING, + timestamp=timestamp, + metrics=[ + (timestamp, "some-metric", tunable_value_norm + random() / 100), + ], + ) + trial.update( + Status.SUCCEEDED, + timestamp, + metrics={ + # Give some variance on the score. + # And some influence from the tunable value. + "score": tunable_value_norm + + random() / 100 + }, + ) return exp @@ -167,7 +184,9 @@ def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Ex @pytest.fixture -def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: +def exp_no_tunables_storage_with_trials( + exp_no_tunables_storage: SqlStorage.Experiment, +) -> SqlStorage.Experiment: """ Test fixture for Experiment using in-memory SQLite3 storage. """ @@ -176,7 +195,9 @@ def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Expe @pytest.fixture -def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: +def mixed_numerics_exp_storage_with_trials( + mixed_numerics_exp_storage: SqlStorage.Experiment, +) -> SqlStorage.Experiment: """ Test fixture for Experiment using in-memory SQLite3 storage. """ @@ -185,7 +206,9 @@ def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorag @pytest.fixture -def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def exp_data( + storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """ Test fixture for ExperimentData. """ @@ -193,7 +216,9 @@ def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment @pytest.fixture -def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def exp_no_tunables_data( + storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """ Test fixture for ExperimentData with no tunable configs. 
""" @@ -201,7 +226,9 @@ def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trial @pytest.fixture -def mixed_numerics_exp_data(storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def mixed_numerics_exp_data( + storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """ Test fixture for ExperimentData with mixed numerical tunable types. """ diff --git a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py index ba965ed3c6..088daca84a 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py @@ -13,8 +13,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_exp_trial_pending(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: """ Schedule a trial and check that it is pending and has the right configuration. """ @@ -31,13 +30,12 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, } -def test_exp_trial_configs(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: """ Start multiple trials with two different configs and check that we store only two config objects in the DB. """ - config1 = tunable_groups.copy().assign({'idle': 'mwait'}) + config1 = tunable_groups.copy().assign({"idle": "mwait"}) trials1 = [ exp_storage.new_trial(config1), exp_storage.new_trial(config1), @@ -46,7 +44,7 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment, assert trials1[0].tunable_config_id == trials1[1].tunable_config_id assert trials1[0].tunable_config_id == trials1[2].tunable_config_id - config2 = tunable_groups.copy().assign({'idle': 'halt'}) + config2 = tunable_groups.copy().assign({"idle": "halt"}) trials2 = [ exp_storage.new_trial(config2), exp_storage.new_trial(config2), diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index 04f4f18ae3..debd983cf0 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -22,8 +22,7 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]: return set(t.trial_id for t in trials) -def test_schedule_trial(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: """ Schedule several trials for future execution and retrieve them later at certain timestamps. 
""" @@ -44,16 +43,14 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: - pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) + pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, } # Get trials scheduled to run within the next 1 hour: - pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) + pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -62,7 +59,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -84,7 +82,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) + ) assert pending_ids == { trial_1h.trial_id, trial_2h.trial_id, @@ -92,7 +91,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run OR running within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -114,7 +114,9 @@ def test_schedule_trial(exp_storage: Storage.Experiment, assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED] # Get only trials completed after trial_now2: - (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(last_trial_id=trial_now2.trial_id) + (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load( + last_trial_id=trial_now2.trial_id + ) assert trial_ids == [trial_1h.trial_id] assert len(trial_configs) == len(trial_scores) == 1 assert trial_status == [Status.SUCCEEDED] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index 855c6cd861..449b564395 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -31,18 +31,21 @@ def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, st """ timestamp1 = datetime.now(zone_info) timestamp2 = timestamp1 + timedelta(seconds=1) - return sorted([ - (timestamp1, "cpu_load", 10.1), - (timestamp1, "memory", 20), - (timestamp1, "setup", "prod"), - (timestamp2, "cpu_load", 30.1), - (timestamp2, "memory", 40), - (timestamp2, "setup", "prod"), - ]) + return sorted( + [ + (timestamp1, "cpu_load", 10.1), + (timestamp1, "memory", 20), + (timestamp1, "setup", "prod"), + (timestamp2, "cpu_load", 30.1), + (timestamp2, "memory", 40), + (timestamp2, "setup", "prod"), + ] + ) -def _telemetry_str(data: List[Tuple[datetime, str, Any]] - ) -> List[Tuple[datetime, str, Optional[str]]]: +def _telemetry_str( + data: 
List[Tuple[datetime, str, Any]] +) -> List[Tuple[datetime, str, Optional[str]]]: """ Convert telemetry values to strings. """ @@ -51,10 +54,12 @@ def _telemetry_str(data: List[Tuple[datetime, str, Any]] @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry(storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo]) -> None: +def test_update_telemetry( + storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo], +) -> None: """ Make sure update_telemetry() and load_telemetry() methods work. """ @@ -73,9 +78,11 @@ def test_update_telemetry(storage: Storage, @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry_twice(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo]) -> None: +def test_update_telemetry_twice( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo], +) -> None: """ Make sure update_telemetry() call is idempotent. """ diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index 3b57222822..251c50b241 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -10,8 +10,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_trial_data_tunable_config_data(exp_data: ExperimentData, - tunable_groups: TunableGroups) -> None: +def test_trial_data_tunable_config_data( + exp_data: ExperimentData, tunable_groups: TunableGroups +) -> None: """ Check expected return values for TunableConfigData. """ @@ -29,12 +30,12 @@ def test_trial_metadata(exp_data: ExperimentData) -> None: """ Check expected return values for TunableConfigData metadata. """ - assert exp_data.objectives == {'score': 'min'} - for (trial_id, trial) in exp_data.trials.items(): + assert exp_data.objectives == {"score": "min"} + for trial_id, trial in exp_data.trials.items(): assert trial.metadata_dict == { - 'opt_target_0': 'score', - 'opt_direction_0': 'min', - 'trial_number': trial_id, + "opt_target_0": "score", + "opt_direction_0": "min", + "trial_number": trial_id, } @@ -48,13 +49,13 @@ def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData def test_mixed_numerics_exp_trial_data( - mixed_numerics_exp_data: ExperimentData, - mixed_numerics_tunable_groups: TunableGroups) -> None: + mixed_numerics_exp_data: ExperimentData, mixed_numerics_tunable_groups: TunableGroups +) -> None: """ Tests that data type conversions are retained when loading experiment data with mixed numeric tunable types. 
""" trial = next(iter(mixed_numerics_exp_data.trials.values())) config = trial.tunable_config.config_dict - for (tunable, _group) in mixed_numerics_tunable_groups: + for tunable, _group in mixed_numerics_tunable_groups: assert isinstance(config[tunable.name], tunable.dtype) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index d08b26e92d..fd57d07635 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -16,10 +16,15 @@ def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None: trial_id = 1 trial = exp_data.trials[trial_id] tunable_config_trial_group = trial.tunable_config_trial_group - assert tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id + assert ( + tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id + ) assert tunable_config_trial_group.tunable_config_id == trial.tunable_config_id assert tunable_config_trial_group.tunable_config == trial.tunable_config - assert tunable_config_trial_group == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group + assert ( + tunable_config_trial_group + == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group + ) def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None: @@ -49,7 +54,9 @@ def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) # And so on ... -def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: +def test_tunable_config_trial_group_results_df( + exp_data: ExperimentData, tunable_groups: TunableGroups +) -> None: """Tests the results_df property of the TunableConfigTrialGroup.""" tunable_config_id = 2 expected_group_id = 4 @@ -58,9 +65,14 @@ def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable # We shouldn't have the results for the other configs, just this one. 
expected_count = CONFIG_TRIAL_REPEAT_COUNT assert len(results_df) == expected_count - assert len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count + assert ( + len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count + ) assert len(results_df[(results_df["tunable_config_id"] != tunable_config_id)]) == 0 - assert len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) == expected_count + assert ( + len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) + == expected_count + ) assert len(results_df[(results_df["tunable_config_trial_group_id"] != expected_group_id)]) == 0 assert len(results_df["trial_id"].unique()) == expected_count obj_target = next(iter(exp_data.objectives)) @@ -76,8 +88,14 @@ def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None: tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id] trials = tunable_config_trial_group.trials assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT - assert all(trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id - for trial in trials.values()) - assert all(trial.tunable_config_id == tunable_config_id - for trial in tunable_config_trial_group.trials.values()) - assert exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] + assert all( + trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id + for trial in trials.values() + ) + assert all( + trial.tunable_config_id == tunable_config_id + for trial in tunable_config_trial_group.trials.values() + ) + assert ( + exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] + ) diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py index fa947610da..c3acd9d243 100644 --- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py +++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py @@ -24,7 +24,7 @@ ] -@pytest.mark.skipif(sys.platform == 'win32', reason="TZ environment variable is a UNIXism") +@pytest.mark.skipif(sys.platform == "win32", reason="TZ environment variable is a UNIXism") @pytest.mark.parametrize(("tz_name"), ZONE_NAMES) @pytest.mark.parametrize(("test_file"), TZ_TEST_FILES) def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: @@ -45,4 +45,6 @@ def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: if cmd.returncode != 0: print(cmd.stdout.decode()) print(cmd.stderr.decode()) - raise AssertionError(f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'") + raise AssertionError( + f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'" + ) diff --git a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py index 822547b1da..8329b51bd0 100644 --- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py +++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py @@ -119,24 +119,26 @@ def mixed_numerics_tunable_groups() -> TunableGroups: tunable_groups : TunableGroups A new TunableGroups object for testing. 
""" - tunables = TunableGroups({ - "mix-numerics": { - "cost": 1, - "params": { - "int": { - "description": "An integer", - "type": "int", - "default": 0, - "range": [0, 100], + tunables = TunableGroups( + { + "mix-numerics": { + "cost": 1, + "params": { + "int": { + "description": "An integer", + "type": "int", + "default": 0, + "range": [0, 100], + }, + "float": { + "description": "A float", + "type": "float", + "default": 0, + "range": [0, 1], + }, }, - "float": { - "description": "A float", - "type": "float", - "default": 0, - "range": [0, 1], - }, - } - }, - }) + }, + } + ) tunables.reset() return tunables diff --git a/mlos_bench/mlos_bench/tests/tunables/conftest.py b/mlos_bench/mlos_bench/tests/tunables/conftest.py index 95de20d9b8..878471b59e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/conftest.py +++ b/mlos_bench/mlos_bench/tests/tunables/conftest.py @@ -25,12 +25,15 @@ def tunable_categorical() -> Tunable: tunable : Tunable An instance of a categorical Tunable. """ - return Tunable("vmSize", { - "description": "Azure VM size", - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] - }) + return Tunable( + "vmSize", + { + "description": "Azure VM size", + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + }, + ) @pytest.fixture @@ -43,13 +46,16 @@ def tunable_int() -> Tunable: tunable : Tunable An instance of an integer Tunable. """ - return Tunable("kernel_sched_migration_cost_ns", { - "description": "Cost of migrating the thread to another core", - "type": "int", - "default": 40000, - "range": [0, 500000], - "special": [-1] # Special value outside of the range - }) + return Tunable( + "kernel_sched_migration_cost_ns", + { + "description": "Cost of migrating the thread to another core", + "type": "int", + "default": 40000, + "range": [0, 500000], + "special": [-1], # Special value outside of the range + }, + ) @pytest.fixture @@ -62,9 +68,12 @@ def tunable_float() -> Tunable: tunable : Tunable An instance of a float Tunable. """ - return Tunable("chaos_monkey_prob", { - "description": "Probability of spontaneous VM shutdown", - "type": "float", - "default": 0.01, - "range": [0, 1] - }) + return Tunable( + "chaos_monkey_prob", + { + "description": "Probability of spontaneous VM shutdown", + "type": "float", + "default": 0.01, + "range": [0, 1], + }, + ) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py index 0e910f3761..e8b3e6b4cc 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py @@ -38,7 +38,7 @@ def test_tunable_categorical_types() -> None: "values": ["a", "b", "c"], "default": "a", }, - } + }, } } tunable_groups = TunableGroups(tunable_params) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index 58bb0368b1..c42ae21676 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -14,6 +14,7 @@ # Note: these test do *not* check the ConfigSpace conversions for those same Tunables. 
# That is checked indirectly via grid_search_optimizer_test.py + def test_tunable_int_size_props() -> None: """Test tunable int size properties""" tunable = Tunable( @@ -22,7 +23,8 @@ def test_tunable_int_size_props() -> None: "type": "int", "range": [1, 5], "default": 3, - }) + }, + ) assert tunable.span == 4 assert tunable.cardinality == 5 expected = [1, 2, 3, 4, 5] @@ -38,7 +40,8 @@ def test_tunable_float_size_props() -> None: "type": "float", "range": [1.5, 5], "default": 3, - }) + }, + ) assert tunable.span == 3.5 assert tunable.cardinality == np.inf assert tunable.quantized_values is None @@ -53,7 +56,8 @@ def test_tunable_categorical_size_props() -> None: "type": "categorical", "values": ["a", "b", "c"], "default": "a", - }) + }, + ) with pytest.raises(AssertionError): _ = tunable.span assert tunable.cardinality == 3 @@ -66,12 +70,8 @@ def test_tunable_quantized_int_size_props() -> None: """Test quantized tunable int size properties""" tunable = Tunable( name="test", - config={ - "type": "int", - "range": [100, 1000], - "default": 100, - "quantization": 100 - }) + config={"type": "int", "range": [100, 1000], "default": 100, "quantization": 100}, + ) assert tunable.span == 900 assert tunable.cardinality == 10 expected = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] @@ -82,13 +82,8 @@ def test_tunable_quantized_int_size_props() -> None: def test_tunable_quantized_float_size_props() -> None: """Test quantized tunable float size properties""" tunable = Tunable( - name="test", - config={ - "type": "float", - "range": [0, 1], - "default": 0, - "quantization": .1 - }) + name="test", config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1} + ) assert tunable.span == 1 assert tunable.cardinality == 11 expected = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py index 6a91b14016..407998b3a4 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py @@ -28,7 +28,7 @@ def test_tunable_int_name_lt(tunable_int: Tunable) -> None: Tests that the __lt__ operator works as expected. 
""" tunable_int_2 = tunable_int.copy() - tunable_int_2._name = "aaa" # pylint: disable=protected-access + tunable_int_2._name = "aaa" # pylint: disable=protected-access assert tunable_int_2 < tunable_int @@ -38,7 +38,8 @@ def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None: """ tunable_categorical_2 = tunable_categorical.copy() new_value = [ - x for x in tunable_categorical.categories + x + for x in tunable_categorical.categories if x != tunable_categorical.category and x is not None ][0] assert tunable_categorical.category is not None @@ -59,7 +60,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - } + }, ) tunable_dog = Tunable( name="same-name", @@ -67,7 +68,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": [None, "doggo"], "default": None, - } + }, ) assert tunable_dog < tunable_cat @@ -82,7 +83,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - } + }, ) tunable_int = Tunable( name="same-name", @@ -90,7 +91,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "int", "range": [1, 3], "default": 2, - } + }, ) assert tunable_cat < tunable_int @@ -101,7 +102,7 @@ def test_tunable_lt_different_object(tunable_int: Tunable) -> None: """ assert (tunable_int < "foo") is False with pytest.raises(TypeError): - assert "foo" < tunable_int # type: ignore[operator] + assert "foo" < tunable_int # type: ignore[operator] def test_tunable_group_ne_object(tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py index f2da3ba60e..980fda06a4 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py @@ -18,7 +18,7 @@ def test_tunable_name() -> None: """ with pytest.raises(ValueError): # ! 
characters are currently disallowed in tunable names - Tunable(name='test!tunable', config={"type": "float", "range": [0, 1], "default": 0}) + Tunable(name="test!tunable", config={"type": "float", "range": [0, 1], "default": 0}) def test_categorical_required_params() -> None: @@ -34,7 +34,7 @@ def test_categorical_required_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_weights() -> None: @@ -50,7 +50,7 @@ def test_categorical_weights() -> None: } """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.weights == [25, 25, 50] @@ -68,7 +68,7 @@ def test_categorical_weights_wrong_count() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_weights_wrong_values() -> None: @@ -85,7 +85,7 @@ def test_categorical_weights_wrong_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_wrong_params() -> None: @@ -102,7 +102,7 @@ def test_categorical_wrong_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_disallow_special_values() -> None: @@ -119,7 +119,7 @@ def test_categorical_disallow_special_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_tunable_disallow_repeats() -> None: @@ -127,11 +127,14 @@ def test_categorical_tunable_disallow_repeats() -> None: Disallow duplicate values in categorical tunables. """ with pytest.raises(ValueError): - Tunable(name='test', config={ - "type": "categorical", - "values": ["foo", "bar", "foo"], - "default": "foo", - }) + Tunable( + name="test", + config={ + "type": "categorical", + "values": ["foo", "bar", "foo"], + "default": "foo", + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -140,11 +143,14 @@ def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeN Disallow null values as default for numerical tunables. """ with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "default": None, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "default": None, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -153,11 +159,14 @@ def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeN Disallow out of range values as default for numerical tunables. """ with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "default": 11, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "default": 11, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -166,12 +175,15 @@ def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> N Disallow values param for numerical tunables. 
""" with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "values": ["foo", "bar"], - "default": 0, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "values": ["foo", "bar"], + "default": 0, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -188,7 +200,7 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -205,7 +217,7 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(AssertionError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -222,7 +234,7 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -241,7 +253,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.special == [0] assert tunable.weights == [0.1] assert tunable.range_weight == 0.9 @@ -261,7 +273,7 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.quantization == 10 assert not tunable.is_log @@ -280,7 +292,7 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.is_log @@ -299,7 +311,7 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -319,7 +331,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.special == [-1, 0] assert tunable.weights == [0, 10] # Zero weights are ok assert tunable.range_weight == 90 @@ -342,7 +354,7 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -361,7 +373,7 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -380,7 
+392,7 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -398,7 +410,7 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -418,7 +430,7 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -436,7 +448,7 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_bad_type() -> None: @@ -452,4 +464,4 @@ def test_bad_type() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test_bad_type', config=config) + Tunable(name="test_bad_type", config=config) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py index deffcb6a46..e8817319ab 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py @@ -17,14 +17,15 @@ def test_categorical_distribution() -> None: Try to instantiate a categorical tunable with distribution specified. """ with pytest.raises(ValueError): - Tunable(name='test', config={ - "type": "categorical", - "values": ["foo", "bar", "baz"], - "distribution": { - "type": "uniform" + Tunable( + name="test", + config={ + "type": "categorical", + "values": ["foo", "bar", "baz"], + "distribution": {"type": "uniform"}, + "default": "foo", }, - "default": "foo" - }) + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -32,14 +33,15 @@ def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> N """ Create a numeric Tunable with explicit uniform distribution. """ - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "uniform" + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "uniform"}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.is_numerical assert tunable.distribution == "uniform" assert not tunable.distribution_params @@ -50,18 +52,15 @@ def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> No """ Create a numeric Tunable with explicit Gaussian distribution specified. 
""" - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "normal", - "params": { - "mu": 0, - "sigma": 1.0 - } + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "normal", "params": {"mu": 0, "sigma": 1.0}}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.distribution == "normal" assert tunable.distribution_params == {"mu": 0, "sigma": 1.0} @@ -71,18 +70,15 @@ def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None """ Create a numeric Tunable with explicit Beta distribution specified. """ - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "beta", - "params": { - "alpha": 2, - "beta": 5 - } + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "beta", "params": {"alpha": 2, "beta": 5}}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.distribution == "beta" assert tunable.distribution_params == {"alpha": 2, "beta": 5} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py index c6fb5670f0..d9b209cf4f 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py @@ -10,7 +10,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categorical: Tunable) -> None: +def test_tunable_group_indexing( + tunable_groups: TunableGroups, tunable_categorical: Tunable +) -> None: """ Check that various types of indexing work for the tunable group. """ diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py index 55a485e951..186de4acfa 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py @@ -14,4 +14,4 @@ def test_tunable_group_subgroup(tunable_groups: TunableGroups) -> None: Check that the subgroup() method returns only a selection of tunable parameters. 
""" tunables = tunable_groups.subgroup(["provision"]) - assert tunables.get_param_values() == {'vmSize': 'Standard_B4ms'} + assert tunables.get_param_values() == {"vmSize": "Standard_B4ms"} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index 73e3a12caa..0dfbdd2acd 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -36,37 +36,39 @@ @pytest.mark.parametrize("param_type", ["int", "float"]) -@pytest.mark.parametrize("distr_name,distr_params", [ - ("normal", {"mu": 0.0, "sigma": 1.0}), - ("beta", {"alpha": 2, "beta": 5}), - ("uniform", {}), -]) -def test_convert_numerical_distributions(param_type: str, - distr_name: DistributionName, - distr_params: dict) -> None: +@pytest.mark.parametrize( + "distr_name,distr_params", + [ + ("normal", {"mu": 0.0, "sigma": 1.0}), + ("beta", {"alpha": 2, "beta": 5}), + ("uniform", {}), + ], +) +def test_convert_numerical_distributions( + param_type: str, distr_name: DistributionName, distr_params: dict +) -> None: """ Convert a numerical Tunable with explicit distribution to ConfigSpace. """ tunable_name = "x" - tunable_groups = TunableGroups({ - "tunable_group": { - "cost": 1, - "params": { - tunable_name: { - "type": param_type, - "range": [0, 100], - "special": [-1, 0], - "special_weights": [0.1, 0.2], - "range_weight": 0.7, - "distribution": { - "type": distr_name, - "params": distr_params - }, - "default": 0 - } + tunable_groups = TunableGroups( + { + "tunable_group": { + "cost": 1, + "params": { + tunable_name: { + "type": param_type, + "range": [0, 100], + "special": [-1, 0], + "special_weights": [0.1, 0.2], + "range_weight": 0.7, + "distribution": {"type": distr_name, "params": distr_params}, + "default": 0, + } + }, } } - }) + ) (tunable, _group) = tunable_groups.get_tunable(tunable_name) assert tunable.distribution == distr_name @@ -82,5 +84,5 @@ def test_convert_numerical_distributions(param_type: str, cs_param = space[tunable_name] assert isinstance(cs_param, _CS_HYPERPARAMETER[param_type, distr_name]) - for (key, val) in distr_params.items(): + for key, val in distr_params.items(): assert getattr(cs_param, key) == val diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 78e91fd25e..39bd41e282 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -38,17 +38,23 @@ def configuration_space() -> ConfigurationSpace: configuration_space : ConfigurationSpace A new ConfigurationSpace object for testing. 
""" - (kernel_sched_migration_cost_ns_special, - kernel_sched_migration_cost_ns_type) = special_param_names("kernel_sched_migration_cost_ns") - - spaces = ConfigurationSpace(space={ - "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - "idle": ["halt", "mwait", "noidle"], - "kernel_sched_migration_cost_ns": (0, 500000), - kernel_sched_migration_cost_ns_special: [-1, 0], - kernel_sched_migration_cost_ns_type: [TunableValueKind.SPECIAL, TunableValueKind.RANGE], - "kernel_sched_latency_ns": (0, 1000000000), - }) + (kernel_sched_migration_cost_ns_special, kernel_sched_migration_cost_ns_type) = ( + special_param_names("kernel_sched_migration_cost_ns") + ) + + spaces = ConfigurationSpace( + space={ + "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + "idle": ["halt", "mwait", "noidle"], + "kernel_sched_migration_cost_ns": (0, 500000), + kernel_sched_migration_cost_ns_special: [-1, 0], + kernel_sched_migration_cost_ns_type: [ + TunableValueKind.SPECIAL, + TunableValueKind.RANGE, + ], + "kernel_sched_latency_ns": (0, 1000000000), + } + ) # NOTE: FLAML requires distribution to be uniform spaces["vmSize"].default_value = "Standard_B4ms" @@ -60,18 +66,25 @@ def configuration_space() -> ConfigurationSpace: spaces[kernel_sched_migration_cost_ns_type].probabilities = (0.5, 0.5) spaces["kernel_sched_latency_ns"].default_value = 2000000 - spaces.add_condition(EqualsCondition( - spaces[kernel_sched_migration_cost_ns_special], - spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.SPECIAL)) - spaces.add_condition(EqualsCondition( - spaces["kernel_sched_migration_cost_ns"], - spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.RANGE)) + spaces.add_condition( + EqualsCondition( + spaces[kernel_sched_migration_cost_ns_special], + spaces[kernel_sched_migration_cost_ns_type], + TunableValueKind.SPECIAL, + ) + ) + spaces.add_condition( + EqualsCondition( + spaces["kernel_sched_migration_cost_ns"], + spaces[kernel_sched_migration_cost_ns_type], + TunableValueKind.RANGE, + ) + ) return spaces -def _cmp_tunable_hyperparameter_categorical( - tunable: Tunable, space: ConfigurationSpace) -> None: +def _cmp_tunable_hyperparameter_categorical(tunable: Tunable, space: ConfigurationSpace) -> None: """ Check if categorical Tunable and ConfigSpace Hyperparameter actually match. """ @@ -81,8 +94,7 @@ def _cmp_tunable_hyperparameter_categorical( assert param.default_value == tunable.value -def _cmp_tunable_hyperparameter_numerical( - tunable: Tunable, space: ConfigurationSpace) -> None: +def _cmp_tunable_hyperparameter_numerical(tunable: Tunable, space: ConfigurationSpace) -> None: """ Check if integer Tunable and ConfigSpace Hyperparameter actually match. """ @@ -130,12 +142,13 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> Non Make sure that the corresponding Tunable and Hyperparameter objects match. """ space = tunable_groups_to_configspace(tunable_groups) - for (tunable, _group) in tunable_groups: + for tunable, _group in tunable_groups: _CMP_FUNC[tunable.type](tunable, space) def test_tunable_groups_to_configspace( - tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None: + tunable_groups: TunableGroups, configuration_space: ConfigurationSpace +) -> None: """ Check the conversion of the entire TunableGroups collection to a single ConfigurationSpace object. 
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index cbccd6bfe1..2f7790602f 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -19,12 +19,14 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: that don't exist in the TunableGroups object. """ with pytest.raises(KeyError): - tunable_groups.assign({ - "vmSize": "Standard_B2ms", - "idle": "mwait", - "UnknownParam_1": 1, - "UnknownParam_2": "invalid-value" - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2ms", + "idle": "mwait", + "UnknownParam_1": 1, + "UnknownParam_2": "invalid-value", + } + ) def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None: @@ -106,7 +108,7 @@ def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None: Check str to int coercion. """ tunable_int.value = "10" - assert tunable_int.value == 10 # type: ignore[comparison-overlap] + assert tunable_int.value == 10 # type: ignore[comparison-overlap] assert not tunable_int.is_special @@ -115,7 +117,7 @@ def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None: Check str to float coercion. """ tunable_float.value = "0.5" - assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] + assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] assert not tunable_float.is_special @@ -149,12 +151,12 @@ def test_tunable_assign_null_to_categorical() -> None: } """ config = json.loads(json_config) - categorical_tunable = Tunable(name='categorical_test', config=config) + categorical_tunable = Tunable(name="categorical_test", config=config) assert categorical_tunable assert categorical_tunable.category == "foo" categorical_tunable.value = None assert categorical_tunable.value is None - assert categorical_tunable.value != 'None' + assert categorical_tunable.value != "None" assert categorical_tunable.category is None @@ -165,7 +167,7 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_int.value = None with pytest.raises((TypeError, AssertionError)): - tunable_int.numerical_value = None # type: ignore[assignment] + tunable_int.numerical_value = None # type: ignore[assignment] def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: @@ -175,7 +177,7 @@ def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_float.value = None with pytest.raises((TypeError, AssertionError)): - tunable_float.numerical_value = None # type: ignore[assignment] + tunable_float.numerical_value = None # type: ignore[assignment] def test_tunable_assign_special(tunable_int: Tunable) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py index 672b16ab73..cb41f7f7d8 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py @@ -17,42 +17,44 @@ def test_tunable_groups_str(tunable_groups: TunableGroups) -> None: tunables within each covariant group. 
""" # Same as `tunable_groups` (defined in the `conftest.py` file), but in different order: - tunables_other = TunableGroups({ - "kernel": { - "cost": 1, - "params": { - "kernel_sched_latency_ns": { - "type": "int", - "default": 2000000, - "range": [0, 1000000000] + tunables_other = TunableGroups( + { + "kernel": { + "cost": 1, + "params": { + "kernel_sched_latency_ns": { + "type": "int", + "default": 2000000, + "range": [0, 1000000000], + }, + "kernel_sched_migration_cost_ns": { + "type": "int", + "default": -1, + "range": [0, 500000], + "special": [-1], + }, }, - "kernel_sched_migration_cost_ns": { - "type": "int", - "default": -1, - "range": [0, 500000], - "special": [-1] - } - } - }, - "boot": { - "cost": 300, - "params": { - "idle": { - "type": "categorical", - "default": "halt", - "values": ["halt", "mwait", "noidle"] - } - } - }, - "provision": { - "cost": 1000, - "params": { - "vmSize": { - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] - } - } - }, - }) + }, + "boot": { + "cost": 300, + "params": { + "idle": { + "type": "categorical", + "default": "halt", + "values": ["halt", "mwait", "noidle"], + } + }, + }, + "provision": { + "cost": 1000, + "params": { + "vmSize": { + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + } + }, + }, + } + ) assert str(tunable_groups) == str(tunables_other) diff --git a/mlos_bench/mlos_bench/tunables/__init__.py b/mlos_bench/mlos_bench/tunables/__init__.py index 4191f37d89..3433f4a735 100644 --- a/mlos_bench/mlos_bench/tunables/__init__.py +++ b/mlos_bench/mlos_bench/tunables/__init__.py @@ -10,7 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups __all__ = [ - 'Tunable', - 'TunableValue', - 'TunableGroups', + "Tunable", + "TunableValue", + "TunableGroups", ] diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index fee4fd5841..797510a087 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -93,10 +93,12 @@ def __eq__(self, other: object) -> bool: return False # TODO: May need to provide logic to relax the equality check on the # tunables (e.g. "compatible" vs. "equal"). 
- return (self._name == other._name and - self._cost == other._cost and - self._is_updated == other._is_updated and - self._tunables == other._tunables) + return ( + self._name == other._name + and self._cost == other._cost + and self._is_updated == other._is_updated + and self._tunables == other._tunables + ) def equals_defaults(self, other: "CovariantTunableGroup") -> bool: """ @@ -234,7 +236,11 @@ def __contains__(self, tunable: Union[str, Tunable]) -> bool: def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: return self.get_tunable(tunable).value - def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: - value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + def __setitem__( + self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + ) -> TunableValue: + value: TunableValue = ( + tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + ) self._is_updated |= self.get_tunable(tunable).update(value) return value diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 1ebd70dfa4..b2a465c71a 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -107,7 +107,7 @@ def __init__(self, name: str, config: TunableDict): config : dict Python dict that represents a Tunable (e.g., deserialized from JSON) """ - if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema + if not isinstance(name, str) or "!" in name: # TODO: Use a regex here and in JSON schema raise ValueError(f"Invalid name of the tunable: {name}") self._name = name self._type: TunableValueTypeName = config["type"] # required @@ -202,10 +202,16 @@ def _sanity_check_numerical(self) -> None: raise ValueError(f"Number of quantization points is <= 1: {self}") if self.dtype == float: if not isinstance(self._quantization, (float, int)): - raise ValueError(f"Quantization of a float param should be a float or int: {self}") + raise ValueError( + f"Quantization of a float param should be a float or int: {self}" + ) if self._quantization <= 0: raise ValueError(f"Number of quantization points is <= 0: {self}") - if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}: + if self._distribution is not None and self._distribution not in { + "uniform", + "normal", + "beta", + }: raise ValueError(f"Invalid distribution: {self}") if self._distribution_params and self._distribution is None: raise ValueError(f"Must specify the distribution: {self}") @@ -230,7 +236,9 @@ def __repr__(self) -> str: """ # TODO? Add weights, specials, quantization, distribution? 
if self.is_categorical: - return f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}" + return ( + f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}" + ) return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}" def __eq__(self, other: object) -> bool: @@ -251,12 +259,12 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Tunable): return False return bool( - self._name == other._name and - self._type == other._type and - self._current_value == other._current_value + self._name == other._name + and self._type == other._type + and self._current_value == other._current_value ) - def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements + def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements """ Compare the two Tunable objects. We mostly need this to create a canonical list of tunable objects when hashing a TunableGroup. @@ -336,18 +344,21 @@ def value(self, value: TunableValue) -> TunableValue: assert value is not None coerced_value = self.dtype(value) except Exception: - _LOG.error("Impossible conversion: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Impossible conversion: %s %s <- %s %s", self._type, self._name, type(value), value + ) raise if self._type == "int" and isinstance(value, float) and value != coerced_value: - _LOG.error("Loss of precision: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Loss of precision: %s %s <- %s %s", self._type, self._name, type(value), value + ) raise ValueError(f"Loss of precision: {self._name}={value}") if not self.is_valid(coerced_value): - _LOG.error("Invalid assignment: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Invalid assignment: %s %s <- %s %s", self._type, self._name, type(value), value + ) raise ValueError(f"Invalid value for the tunable: {self._name}={value}") self._current_value = coerced_value @@ -403,10 +414,10 @@ def in_range(self, value: Union[int, float, str, None]) -> bool: Return False if the tunable or value is categorical or None. """ return ( - isinstance(value, (float, int)) and - self.is_numerical and - self._range is not None and - bool(self._range[0] <= value <= self._range[1]) + isinstance(value, (float, int)) + and self.is_numerical + and self._range is not None + and bool(self._range[0] <= value <= self._range[1]) ) @property @@ -626,10 +637,12 @@ def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]: # Be sure to return python types instead of numpy types. 
cardinality = self.cardinality assert isinstance(cardinality, int) - return (float(x) for x in np.linspace(start=num_range[0], - stop=num_range[1], - num=cardinality, - endpoint=True)) + return ( + float(x) + for x in np.linspace( + start=num_range[0], stop=num_range[1], num=cardinality, endpoint=True + ) + ) assert self.type == "int", f"Unhandled tunable type: {self}" return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1)) diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index 0bd58c8269..8fbaee878c 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -30,9 +30,11 @@ def __init__(self, config: Optional[dict] = None): if config is None: config = {} ConfigSchema.TUNABLE_PARAMS.validate(config) - self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup) + self._index: Dict[str, CovariantTunableGroup] = ( + {} + ) # Index (Tunable id -> CovariantTunableGroup) self._tunable_groups: Dict[str, CovariantTunableGroup] = {} - for (name, group_config) in config.items(): + for name, group_config in config.items(): self._add_group(CovariantTunableGroup(name, group_config)) def __bool__(self) -> bool: @@ -81,11 +83,15 @@ def _add_group(self, group: CovariantTunableGroup) -> None: ---------- group : CovariantTunableGroup """ - assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}" + assert ( + group.name not in self._tunable_groups + ), f"Duplicate covariant tunable group name {group.name} in {self}" self._tunable_groups[group.name] = group for tunable in group.get_tunables(): if tunable.name in self._index: - raise ValueError(f"Duplicate Tunable {tunable.name} from group {group.name} in {self}") + raise ValueError( + f"Duplicate Tunable {tunable.name} from group {group.name} in {self}" + ) self._index[tunable.name] = group def merge(self, tunables: "TunableGroups") -> "TunableGroups": @@ -119,8 +125,10 @@ def merge(self, tunables: "TunableGroups") -> "TunableGroups": # Check that there's no overlap in the tunables. # But allow for differing current values. if not self._tunable_groups[group.name].equals_defaults(group): - raise ValueError(f"Overlapping covariant tunable group name {group.name} " + - "in {self._tunable_groups[group.name]} and {tunables}") + raise ValueError( + f"Overlapping covariant tunable group name {group.name} " + + "in {self._tunable_groups[group.name]} and {tunables}" + ) return self def __repr__(self) -> str: @@ -132,10 +140,15 @@ def __repr__(self) -> str: string : str A human-readable version of the TunableGroups. 
""" - return "{ " + ", ".join( - f"{group.name}::{tunable}" - for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) - for tunable in sorted(group._tunables.values())) + " }" + return ( + "{ " + + ", ".join( + f"{group.name}::{tunable}" + for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) + for tunable in sorted(group._tunables.values()) + ) + + " }" + ) def __contains__(self, tunable: Union[str, Tunable]) -> bool: """ @@ -151,13 +164,17 @@ def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: name: str = tunable.name if isinstance(tunable, Tunable) else tunable return self._index[name][name] - def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: + def __setitem__( + self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + ) -> TunableValue: """ Update the current value of a single tunable parameter. """ # Use double index to make sure we set the is_updated flag of the group name: str = tunable.name if isinstance(tunable, Tunable) else tunable - value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + value: TunableValue = ( + tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + ) self._index[name][name] = value return self._index[name][name] @@ -232,8 +249,11 @@ def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": tunables._add_group(self._tunable_groups[name]) return tunables - def get_param_values(self, group_names: Optional[Iterable[str]] = None, - into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: + def get_param_values( + self, + group_names: Optional[Iterable[str]] = None, + into_params: Optional[Dict[str, TunableValue]] = None, + ) -> Dict[str, TunableValue]: """ Get the current values of the tunables that belong to the specified covariance groups. @@ -272,8 +292,10 @@ def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: is_updated : bool True if any of the specified tunable groups has been updated, False otherwise. """ - return any(self._tunable_groups[name].is_updated() - for name in (group_names or self.get_covariant_group_names())) + return any( + self._tunable_groups[name].is_updated() + for name in (group_names or self.get_covariant_group_names()) + ) def is_defaults(self) -> bool: """ @@ -299,7 +321,7 @@ def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "Tuna self : TunableGroups Self-reference for chaining. """ - for name in (group_names or self.get_covariant_group_names()): + for name in group_names or self.get_covariant_group_names(): self._tunable_groups[name].restore_defaults() return self @@ -317,7 +339,7 @@ def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": self : TunableGroups Self-reference for chaining. 
""" - for name in (group_names or self.get_covariant_group_names()): + for name in group_names or self.get_covariant_group_names(): self._tunable_groups[name].reset_is_updated() return self diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 531988be97..619e712497 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -71,8 +71,9 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> return dest -def merge_parameters(*, dest: dict, source: Optional[dict] = None, - required_keys: Optional[Iterable[str]] = None) -> dict: +def merge_parameters( + *, dest: dict, source: Optional[dict] = None, required_keys: Optional[Iterable[str]] = None +) -> dict: """ Merge the source config dict into the destination config. Pick from the source configs *ONLY* the keys that are already present @@ -132,8 +133,9 @@ def path_join(*args: str, abs_path: bool = False) -> str: return os.path.normpath(path).replace("\\", "/") -def prepare_class_load(config: dict, - global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]: +def prepare_class_load( + config: dict, global_config: Optional[Dict[str, Any]] = None +) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. @@ -155,8 +157,9 @@ def prepare_class_load(config: dict, merge_parameters(dest=class_config, source=global_config) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Instantiating: %s with config:\n%s", - class_name, json.dumps(class_config, indent=2)) + _LOG.debug( + "Instantiating: %s with config:\n%s", class_name, json.dumps(class_config, indent=2) + ) return (class_name, class_config) @@ -187,8 +190,9 @@ def get_class_from_name(class_name: str) -> type: # FIXME: Technically, this should return a type "class_name" derived from "base_class". -def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str, - *args: Any, **kwargs: Any) -> BaseTypeVar: +def instantiate_from_config( + base_class: Type[BaseTypeVar], class_name: str, *args: Any, **kwargs: Any +) -> BaseTypeVar: """ Factory method for a new class instantiated from config. 
@@ -238,7 +242,8 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s if missing_params: raise ValueError( "The following parameters must be provided in the configuration" - + f" or as command line arguments: {missing_params}") + + f" or as command line arguments: {missing_params}" + ) def get_git_info(path: str = __file__) -> Tuple[str, str, str]: @@ -257,11 +262,14 @@ def get_git_info(path: str = __file__) -> Tuple[str, str, str]: """ dirname = os.path.dirname(path) git_repo = subprocess.check_output( - ["git", "-C", dirname, "remote", "get-url", "origin"], text=True).strip() + ["git", "-C", dirname, "remote", "get-url", "origin"], text=True + ).strip() git_commit = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "HEAD"], text=True).strip() + ["git", "-C", dirname, "rev-parse", "HEAD"], text=True + ).strip() git_root = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True).strip() + ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True + ).strip() _LOG.debug("Current git branch: %s %s", git_repo, git_commit) rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root)) return (git_repo, git_commit, rel_path.replace("\\", "/")) @@ -355,7 +363,9 @@ def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> raise ValueError(f"Invalid origin: {origin}") -def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]: +def utcify_nullable_timestamp( + timestamp: Optional[datetime], *, origin: Literal["utc", "local"] +) -> Optional[datetime]: """ A nullable version of utcify_timestamp. """ @@ -367,7 +377,9 @@ def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal[ _MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) -def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series: +def datetime_parser( + datetime_col: pandas.Series, *, origin: Literal["utc", "local"] +) -> pandas.Series: """ Attempt to convert a pandas column to a datetime format. @@ -401,7 +413,7 @@ def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "loca new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo) assert new_datetime_col.dt.tz is not None # And convert it to UTC. - new_datetime_col = new_datetime_col.dt.tz_convert('UTC') + new_datetime_col = new_datetime_col.dt.tz_convert("UTC") if new_datetime_col.isna().any(): raise ValueError(f"Invalid date format in the data: {datetime_col}") if new_datetime_col.le(_MIN_TS).any(): diff --git a/mlos_bench/mlos_bench/version.py b/mlos_bench/mlos_bench/version.py index 96d3d2b6bf..f8acae8c02 100644 --- a/mlos_bench/mlos_bench/version.py +++ b/mlos_bench/mlos_bench/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. 
-VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index 27d844c35b..fc29bfbcbb 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -21,15 +21,16 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -47,62 +48,68 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } -extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass +extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass # Additional tools for extra functionality. - 'azure': ['azure-storage-file-share', 'azure-identity', 'azure-keyvault'], - 'ssh': ['asyncssh'], - 'storage-sql-duckdb': ['sqlalchemy', 'duckdb_engine'], - 'storage-sql-mysql': ['sqlalchemy', 'mysql-connector-python'], - 'storage-sql-postgres': ['sqlalchemy', 'psycopg2'], - 'storage-sql-sqlite': ['sqlalchemy'], # sqlite3 comes with python, so we don't need to install it. + "azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"], + "ssh": ["asyncssh"], + "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"], + "storage-sql-mysql": ["sqlalchemy", "mysql-connector-python"], + "storage-sql-postgres": ["sqlalchemy", "psycopg2"], + "storage-sql-sqlite": [ + "sqlalchemy" + ], # sqlite3 comes with python, so we don't need to install it. # Transitive extra_requires from mlos-core. - 'flaml': ['flaml[blendsearch]'], - 'smac': ['smac'], + "flaml": ["flaml[blendsearch]"], + "smac": ["smac"], } # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. 
-extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', - 'pytest-lazy-fixtures', - 'pytest-docker', - 'fasteners', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", + "pytest-lazy-fixtures", + "pytest-docker", + "fasteners", ] setup( version=VERSION, install_requires=[ - 'mlos-core==' + VERSION, - 'requests', - 'json5', - 'jsonschema>=4.18.0', 'referencing>=0.29.1', + "mlos-core==" + VERSION, + "requests", + "json5", + "jsonschema>=4.18.0", + "referencing>=0.29.1", 'importlib_resources;python_version<"3.10"', - ] + extra_requires['storage-sql-sqlite'], # NOTE: For now sqlite is a fallback storage backend, so we always install it. + ] + + extra_requires[ + "storage-sql-sqlite" + ], # NOTE: For now sqlite is a fallback storage backend, so we always install it. extras_require=extra_requires, - **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_bench'), + **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_bench"), ) diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index 086002af62..b3e248e407 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -18,12 +18,12 @@ from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType __all__ = [ - 'SpaceAdapterType', - 'OptimizerFactory', - 'BaseOptimizer', - 'RandomOptimizer', - 'FlamlOptimizer', - 'SmacOptimizer', + "SpaceAdapterType", + "OptimizerFactory", + "BaseOptimizer", + "RandomOptimizer", + "FlamlOptimizer", + "SmacOptimizer", ] @@ -45,7 +45,7 @@ class OptimizerType(Enum): # ConcreteOptimizer = TypeVar('ConcreteOptimizer', *[member.value for member in OptimizerType]) # To address this, we add a test for complete coverage of the enum. ConcreteOptimizer = TypeVar( - 'ConcreteOptimizer', + "ConcreteOptimizer", RandomOptimizer, FlamlOptimizer, SmacOptimizer, @@ -60,13 +60,15 @@ class OptimizerFactory: # pylint: disable=too-few-public-methods @staticmethod - def create(*, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, - optimizer_kwargs: Optional[dict] = None, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None) -> ConcreteOptimizer: # type: ignore[type-var] + def create( + *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, + optimizer_kwargs: Optional[dict] = None, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None, + ) -> ConcreteOptimizer: # type: ignore[type-var] """ Create a new optimizer instance, given the parameter space, optimizer type, and potential optimizer options. 
@@ -107,7 +109,7 @@ def create(*, parameter_space=parameter_space, optimization_targets=optimization_targets, space_adapter=space_adapter, - **optimizer_kwargs + **optimizer_kwargs, ) return optimizer diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py index 5f32219988..d4f59dfa52 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py @@ -12,6 +12,6 @@ from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer __all__ = [ - 'BaseBayesianOptimizer', - 'SmacOptimizer', + "BaseBayesianOptimizer", + "SmacOptimizer", ] diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 76ff0d9b3a..9d3bcabcb2 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -19,8 +19,9 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): """Abstract base class defining the interface for Bayesian optimization.""" @abstractmethod - def surrogate_predict(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def surrogate_predict( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). Parameters @@ -31,11 +32,12 @@ def surrogate_predict(self, *, configs: pd.DataFrame, context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def acquisition_function(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def acquisition_function( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: """Invokes the acquisition function from this Bayesian optimizer for the given configuration. Parameters @@ -46,4 +48,4 @@ def acquisition_function(self, *, configs: pd.DataFrame, context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 9d8d2a0347..5784a42f12 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -29,19 +29,22 @@ class SmacOptimizer(BaseBayesianOptimizer): Wrapper class for SMAC based Bayesian optimization. 
""" - def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - seed: Optional[int] = 0, - run_name: Optional[str] = None, - output_directory: Optional[str] = None, - max_trials: int = 100, - n_random_init: Optional[int] = None, - max_ratio: Optional[float] = None, - use_default_config: bool = False, - n_random_probability: float = 0.1): + def __init__( + self, + *, # pylint: disable=too-many-locals,too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + seed: Optional[int] = 0, + run_name: Optional[str] = None, + output_directory: Optional[str] = None, + max_trials: int = 100, + n_random_init: Optional[int] = None, + max_ratio: Optional[float] = None, + use_default_config: bool = False, + n_random_probability: float = 0.1, + ): """ Instantiate a new SMAC optimizer wrapper. @@ -124,7 +127,9 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments if output_directory is None: # pylint: disable=consider-using-with try: - self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True) # Argument added in Python 3.10 + self._temp_output_directory = TemporaryDirectory( + ignore_cleanup_errors=True + ) # Argument added in Python 3.10 except TypeError: self._temp_output_directory = TemporaryDirectory() output_directory = self._temp_output_directory.name @@ -146,8 +151,12 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments seed=seed or -1, # if -1, SMAC will generate a random seed internally n_workers=1, # Use a single thread for evaluating trials ) - intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier(scenario, max_config_calls=1) - config_selector: ConfigSelector = Optimizer_Smac.get_config_selector(scenario, retrain_after=1) + intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier( + scenario, max_config_calls=1 + ) + config_selector: ConfigSelector = Optimizer_Smac.get_config_selector( + scenario, retrain_after=1 + ) # TODO: When bulk registering prior configs to rewarm the optimizer, # there is a way to inform SMAC's initial design that we have @@ -158,27 +167,27 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments # See Also: #488 initial_design_args: Dict[str, Union[list, int, float, Scenario]] = { - 'scenario': scenario, + "scenario": scenario, # Workaround a bug in SMAC that sets a default arg to a mutable # value that can cause issues when multiple optimizers are # instantiated with the use_default_config option within the same # process that use different ConfigSpaces so that the second # receives the default config from both as an additional config. - 'additional_configs': [] + "additional_configs": [], } if n_random_init is not None: - initial_design_args['n_configs'] = n_random_init + initial_design_args["n_configs"] = n_random_init if n_random_init > 0.25 * max_trials and max_ratio is None: warning( - 'Number of random initial configs (%d) is ' + - 'greater than 25%% of max_trials (%d). ' + - 'Consider setting max_ratio to avoid SMAC overriding n_random_init.', + "Number of random initial configs (%d) is " + + "greater than 25%% of max_trials (%d). 
" + + "Consider setting max_ratio to avoid SMAC overriding n_random_init.", n_random_init, max_trials, ) if max_ratio is not None: assert isinstance(max_ratio, float) and 0.0 <= max_ratio <= 1.0 - initial_design_args['max_ratio'] = max_ratio + initial_design_args["max_ratio"] = max_ratio # Use the default InitialDesign from SMAC. # (currently SBOL instead of LatinHypercube due to better uniformity @@ -190,7 +199,9 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments # design when generated a random_design for itself via the # get_random_design static method when random_design is None. assert isinstance(n_random_probability, float) and n_random_probability >= 0 - random_design = ProbabilityRandomDesign(probability=n_random_probability, seed=scenario.seed) + random_design = ProbabilityRandomDesign( + probability=n_random_probability, seed=scenario.seed + ) self.base_optimizer = Optimizer_Smac( scenario, @@ -200,7 +211,8 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments random_design=random_design, config_selector=config_selector, multi_objective_algorithm=Optimizer_Smac.get_multi_objective_algorithm( - scenario, objective_weights=self._objective_weights), + scenario, objective_weights=self._objective_weights + ), overwrite=True, logging_level=False, # Use the existing logger ) @@ -241,10 +253,16 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None """ # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946 - raise RuntimeError('This function should never be called.') - - def _register(self, *, configs: pd.DataFrame, - scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + raise RuntimeError("This function should never be called.") + + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. Parameters @@ -271,17 +289,22 @@ def _register(self, *, configs: pd.DataFrame, warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) # Register each trial (one-by-one) - for (config, (_i, score)) in zip(self._to_configspace_configs(configs=configs), scores.iterrows()): + for config, (_i, score) in zip( + self._to_configspace_configs(configs=configs), scores.iterrows() + ): # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance info: TrialInfo = self.trial_info_map.get( - config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed)) + config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) + ) value = TrialValue(cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS) self.base_optimizer.tell(info, value, save=False) # Save optimizer once we register all configs self.base_optimizer.optimizer.save() - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. 
Parameters @@ -310,15 +333,23 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr self.optimizer_parameter_space.check_configuration(trial.config) assert trial.config.config_space == self.optimizer_parameter_space self.trial_info_map[trial.config] = trial - config_df = pd.DataFrame([trial.config], columns=list(self.optimizer_parameter_space.keys())) + config_df = pd.DataFrame( + [trial.config], columns=list(self.optimizer_parameter_space.keys()) + ) return config_df, None - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() - def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def surrogate_predict( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: from smac.utils.configspace import ( convert_configurations_to_array, # pylint: disable=import-outside-toplevel ) @@ -331,16 +362,23 @@ def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataF # pylint: disable=protected-access if len(self._observations) <= self.base_optimizer._initial_design._n_configs: raise RuntimeError( - 'Surrogate model can make predictions *only* after all initial points have been evaluated ' + - f'{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}') + "Surrogate model can make predictions *only* after all initial points have been evaluated " + + f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}" + ) if self.base_optimizer._config_selector._model is None: - raise RuntimeError('Surrogate model is not yet trained') + raise RuntimeError("Surrogate model is not yet trained") - config_array: npt.NDArray = convert_configurations_to_array(self._to_configspace_configs(configs=configs)) + config_array: npt.NDArray = convert_configurations_to_array( + self._to_configspace_configs(configs=configs) + ) mean_predictions, _ = self.base_optimizer._config_selector._model.predict(config_array) - return mean_predictions.reshape(-1,) + return mean_predictions.reshape( + -1, + ) - def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def acquisition_function( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) if self._space_adapter: @@ -348,13 +386,15 @@ def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.Da # pylint: disable=protected-access if self.base_optimizer._config_selector._acquisition_function is None: - raise RuntimeError('Acquisition function is not yet initialized') + raise RuntimeError("Acquisition function is not yet initialized") cs_configs: list = self._to_configspace_configs(configs=configs) - return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape(-1,) + return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape( + -1, + ) def cleanup(self) -> None: - if hasattr(self, '_temp_output_directory') and self._temp_output_directory is not None: + if hasattr(self, "_temp_output_directory") and self._temp_output_directory is not None: 
self._temp_output_directory.cleanup() self._temp_output_directory = None @@ -373,5 +413,5 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. """ return [ ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict()) - for (_, config) in configs.astype('O').iterrows() + for (_, config) in configs.astype("O").iterrows() ] diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 273c89eecc..2df19b8eb2 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -33,13 +33,16 @@ class FlamlOptimizer(BaseOptimizer): # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. _METRIC_NAME = "FLAML_score" - def __init__(self, *, # pylint: disable=too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - low_cost_partial_config: Optional[dict] = None, - seed: Optional[int] = None): + def __init__( + self, + *, # pylint: disable=too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + low_cost_partial_config: Optional[dict] = None, + seed: Optional[int] = None, + ): """ Create an MLOS wrapper for FLAML. @@ -82,14 +85,22 @@ def __init__(self, *, # pylint: disable=too-many-arguments configspace_to_flaml_space, ) - self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(self.optimizer_parameter_space) + self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space( + self.optimizer_parameter_space + ) self.low_cost_partial_config = low_cost_partial_config self.evaluated_samples: Dict[ConfigSpace.Configuration, EvaluatedSample] = {} self._suggested_config: Optional[dict] - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. 
Parameters @@ -111,9 +122,10 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, if metadata is not None: warn(f"Not Implemented: Ignoring metadata {list(metadata.columns)}", UserWarning) - for (_, config), (_, score) in zip(configs.astype('O').iterrows(), scores.iterrows()): + for (_, config), (_, score) in zip(configs.astype("O").iterrows(), scores.iterrows()): cs_config: ConfigSpace.Configuration = ConfigSpace.Configuration( - self.optimizer_parameter_space, values=config.to_dict()) + self.optimizer_parameter_space, values=config.to_dict() + ) if cs_config in self.evaluated_samples: warn(f"Configuration {config} was already registered", UserWarning) self.evaluated_samples[cs_config] = EvaluatedSample( @@ -121,7 +133,9 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, score=float(np.average(score.astype(float), weights=self._objective_weights)), ) - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Sampled at random using ConfigSpace. @@ -144,8 +158,13 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr config: dict = self._get_next_config() return pd.DataFrame(config, index=[0]), None - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() def _target_function(self, config: dict) -> Union[dict, None]: @@ -200,16 +219,14 @@ def _get_next_config(self) -> dict: dict(normalize_config(self.optimizer_parameter_space, conf)) for conf in self.evaluated_samples ] - evaluated_rewards = [ - s.score for s in self.evaluated_samples.values() - ] + evaluated_rewards = [s.score for s in self.evaluated_samples.values()] # Warm start FLAML optimizer self._suggested_config = None tune.run( self._target_function, config=self.flaml_parameter_space, - mode='min', + mode="min", metric=self._METRIC_NAME, points_to_evaluate=points_to_evaluate, evaluated_rewards=evaluated_rewards, @@ -218,6 +235,6 @@ def _get_next_config(self) -> dict: verbose=0, ) if self._suggested_config is None: - raise RuntimeError('FLAML did not produce a suggestion') + raise RuntimeError("FLAML did not produce a suggestion") return self._suggested_config # type: ignore[unreachable] diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 4ab9db5a2f..f96bce7075 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -24,11 +24,14 @@ class BaseOptimizer(metaclass=ABCMeta): Optimizer abstract base class defining the basic interface. """ - def __init__(self, *, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None): + def __init__( + self, + *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + ): """ Create a new instance of the base optimizer. 
@@ -44,8 +47,9 @@ def __init__(self, *, The space adapter class to employ for parameter space transformations. """ self.parameter_space: ConfigSpace.ConfigurationSpace = parameter_space - self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = \ + self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = ( parameter_space if space_adapter is None else space_adapter.target_parameter_space + ) if space_adapter is not None and space_adapter.orig_parameter_space != parameter_space: raise ValueError("Given parameter space differs from the one given to space adapter") @@ -68,8 +72,14 @@ def space_adapter(self) -> Optional[BaseSpaceAdapter]: """Get the space adapter instance (if any).""" return self._space_adapter - def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Wrapper method, which employs the space adapter (if any), before registering the configs and scores. Parameters @@ -87,29 +97,37 @@ def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, """ # Do some input validation. assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(scores.columns) == set(self._optimization_targets), \ - "Mismatched optimization targets." - assert self._has_context is None or self._has_context ^ (context is None), \ - "Context must always be added or never be added." - assert len(configs) == len(scores), \ - "Mismatched number of configs and scores." + assert set(scores.columns) == set( + self._optimization_targets + ), "Mismatched optimization targets." + assert self._has_context is None or self._has_context ^ ( + context is None + ), "Context must always be added or never be added." + assert len(configs) == len(scores), "Mismatched number of configs and scores." if context is not None: - assert len(configs) == len(context), \ - "Mismatched number of configs and context." - assert configs.shape[1] == len(self.parameter_space.values()), \ - "Mismatched configuration shape." + assert len(configs) == len(context), "Mismatched number of configs and context." + assert configs.shape[1] == len( + self.parameter_space.values() + ), "Mismatched configuration shape." self._observations.append((configs, scores, context)) self._has_context = context is not None if self._space_adapter: configs = self._space_adapter.inverse_transform(configs) - assert configs.shape[1] == len(self.optimizer_parameter_space.values()), \ - "Mismatched configuration shape after inverse transform." + assert configs.shape[1] == len( + self.optimizer_parameter_space.values() + ), "Mismatched configuration shape after inverse transform." return self._register(configs=configs, scores=scores, context=context) @abstractmethod - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. Parameters @@ -122,10 +140,11 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, context : pd.DataFrame Not Yet Implemented. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover - def suggest(self, *, context: Optional[pd.DataFrame] = None, - defaults: bool = False) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def suggest( + self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. @@ -149,18 +168,21 @@ def suggest(self, *, context: Optional[pd.DataFrame] = None, configuration = self.space_adapter.inverse_transform(configuration) else: configuration, metadata = self._suggest(context=context) - assert len(configuration) == 1, \ - "Suggest must return a single configuration." - assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), \ - "Optimizer suggested a configuration that does not match the expected parameter space." + assert len(configuration) == 1, "Suggest must return a single configuration." + assert set(configuration.columns).issubset( + set(self.optimizer_parameter_space) + ), "Optimizer suggested a configuration that does not match the expected parameter space." if self._space_adapter: configuration = self._space_adapter.transform(configuration) - assert set(configuration.columns).issubset(set(self.parameter_space)), \ - "Space adapter produced a configuration that does not match the expected parameter space." + assert set(configuration.columns).issubset( + set(self.parameter_space) + ), "Space adapter produced a configuration that does not match the expected parameter space." return configuration, metadata @abstractmethod - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Parameters @@ -176,12 +198,16 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr metadata : Optional[pd.DataFrame] The metadata associated with the given configuration used for evaluations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs as "pending". That is it say, it has been suggested by the optimizer, and an experiment trial has been started. This can be useful for executing multiple trials in parallel, retry logic, etc. @@ -195,7 +221,7 @@ def register_pending(self, *, configs: pd.DataFrame, metadata : Optional[pd.DataFrame] Not Yet Implemented. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ @@ -210,11 +236,17 @@ def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.Data raise ValueError("No observations registered yet.") configs = pd.concat([config for config, _, _ in self._observations]).reset_index(drop=True) scores = pd.concat([score for _, score, _ in self._observations]).reset_index(drop=True) - contexts = pd.concat([pd.DataFrame() if context is None else context - for _, _, context in self._observations]).reset_index(drop=True) + contexts = pd.concat( + [ + pd.DataFrame() if context is None else context + for _, _, context in self._observations + ] + ).reset_index(drop=True) return (configs, scores, contexts if len(contexts.columns) > 0 else None) - def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: + def get_best_observations( + self, *, n_max: int = 1 + ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Get the N best observations so far as a triplet of DataFrames (config, score, context). Default is N=1. The columns are ordered in ASCENDING order of the optimization targets. @@ -234,8 +266,7 @@ def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.Dat raise ValueError("No observations registered yet.") (configs, scores, contexts) = self.get_observations() idx = scores.nsmallest(n_max, columns=self._optimization_targets, keep="first").index - return (configs.loc[idx], scores.loc[idx], - None if contexts is None else contexts.loc[idx]) + return (configs.loc[idx], scores.loc[idx], None if contexts is None else contexts.loc[idx]) def cleanup(self) -> None: """ @@ -253,7 +284,7 @@ def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame: j = 0 for param in self.optimizer_parameter_space.values(): if isinstance(param, ConfigSpace.CategoricalHyperparameter): - for (offset, val) in enumerate(param.choices): + for offset, val in enumerate(param.choices): if config[i][j + offset] == 1: df_dict[param.name].append(val) break diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index 0af785ef20..bf6f85ff88 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -24,8 +24,14 @@ class RandomOptimizer(BaseOptimizer): The parameter space to optimize. """ - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """Registers the given configs and scores. Doesn't do anything on the RandomOptimizer except storing configs for logging. @@ -50,7 +56,9 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, warn(f"Not Implemented: Ignoring context {list(metadata.columns)}", UserWarning) # should we pop them from self.pending_observations? - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. 
Sampled at random using ConfigSpace. @@ -71,9 +79,17 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr if context is not None: # not sure how that works here? warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) - return pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), None - - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + return ( + pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), + None, + ) + + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() # self._pending_observations.append((configs, context)) diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 2e2f585590..73e7f37dc3 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -15,8 +15,8 @@ from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter __all__ = [ - 'IdentityAdapter', - 'LlamaTuneAdapter', + "IdentityAdapter", + "LlamaTuneAdapter", ] @@ -35,7 +35,7 @@ class SpaceAdapterType(Enum): # ConcreteSpaceAdapter = TypeVar('ConcreteSpaceAdapter', *[member.value for member in SpaceAdapterType]) # To address this, we add a test for complete coverage of the enum. ConcreteSpaceAdapter = TypeVar( - 'ConcreteSpaceAdapter', + "ConcreteSpaceAdapter", IdentityAdapter, LlamaTuneAdapter, ) @@ -47,10 +47,12 @@ class SpaceAdapterFactory: # pylint: disable=too-few-public-methods @staticmethod - def create(*, - parameter_space: ConfigSpace.ConfigurationSpace, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None) -> ConcreteSpaceAdapter: # type: ignore[type-var] + def create( + *, + parameter_space: ConfigSpace.ConfigurationSpace, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None, + ) -> ConcreteSpaceAdapter: # type: ignore[type-var] """ Create a new space adapter instance, given the parameter space and potential space adapter options. @@ -75,8 +77,7 @@ def create(*, space_adapter_kwargs = {} space_adapter: ConcreteSpaceAdapter = space_adapter_type.value( - orig_parameter_space=parameter_space, - **space_adapter_kwargs + orig_parameter_space=parameter_space, **space_adapter_kwargs ) return space_adapter diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index 6c3a86fc8a..58d07763f6 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -46,7 +46,7 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: """ Target parameter space that is fed to the underlying optimizer. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @@ -64,7 +64,7 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: Pandas dataframe with a single row, containing the translated configuration. Column names are the parameter names of the original parameter space. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: @@ -84,4 +84,4 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: Dataframe of the translated configurations / parameters. The columns are the parameter names of the target parameter space and the rows are the configurations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 4d3a925cbc..b8abdedfeb 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -19,7 +19,7 @@ from mlos_core.util import normalize_config -class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes +class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes """ Implementation of LlamaTune, a set of parameter space transformation techniques, aimed at improving the sample-efficiency of the underlying optimizer. @@ -28,18 +28,21 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance- DEFAULT_NUM_LOW_DIMS = 16 """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" - DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = .2 + DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2 """Default percentage of bias for each special parameter value""" DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 """Default number of (max) unique values of each parameter, when space discretization is used""" - def __init__(self, *, - orig_parameter_space: ConfigSpace.ConfigurationSpace, - num_low_dims: int = DEFAULT_NUM_LOW_DIMS, - special_param_values: Optional[dict] = None, - max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, - use_approximate_reverse_mapping: bool = False): + def __init__( + self, + *, + orig_parameter_space: ConfigSpace.ConfigurationSpace, + num_low_dims: int = DEFAULT_NUM_LOW_DIMS, + special_param_values: Optional[dict] = None, + max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, + use_approximate_reverse_mapping: bool = False, + ): """ Create a space adapter that employs LlamaTune's techniques. @@ -58,7 +61,9 @@ def __init__(self, *, super().__init__(orig_parameter_space=orig_parameter_space) if num_low_dims >= len(orig_parameter_space): - raise ValueError("Number of target config space dimensions should be less than those of original config space.") + raise ValueError( + "Number of target config space dimensions should be less than those of original config space." 
+ ) # Validate input special param values dict special_param_values = special_param_values or {} @@ -90,9 +95,10 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: target_configurations = [] - for (_, config) in configurations.astype('O').iterrows(): + for _, config in configurations.astype("O").iterrows(): configuration = ConfigSpace.Configuration( - self.orig_parameter_space, values=config.to_dict()) + self.orig_parameter_space, values=config.to_dict() + ) target_config = self._suggested_configs.get(configuration, None) # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection @@ -104,12 +110,15 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # Default configuration should always be registerable. pass elif not self._use_approximate_reverse_mapping: - raise ValueError(f"{repr(configuration)}\n" "The above configuration was not suggested by the optimizer. " - "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " - "previously by the optimizer can be registered.") + raise ValueError( + f"{repr(configuration)}\n" + "The above configuration was not suggested by the optimizer. " + "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " + "previously by the optimizer can be registered." + ) # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix. - if getattr(self, '_pinv_matrix', None) is None: + if getattr(self, "_pinv_matrix", None) is None: self._try_generate_approx_inverse_mapping() # Replace NaNs with zeros for inactive hyperparameters @@ -118,19 +127,27 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # NOTE: applying special value biasing is not possible vector = self._config_scaler.inverse_transform([config_vector])[0] target_config_vector = self._pinv_matrix.dot(vector) - target_config = ConfigSpace.Configuration(self.target_parameter_space, vector=target_config_vector) + target_config = ConfigSpace.Configuration( + self.target_parameter_space, vector=target_config_vector + ) target_configurations.append(target_config) - return pd.DataFrame(target_configurations, columns=list(self.target_parameter_space.keys())) + return pd.DataFrame( + target_configurations, columns=list(self.target_parameter_space.keys()) + ) def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: if len(configuration) != 1: - raise ValueError("Configuration dataframe must contain exactly 1 row. " - f"Found {len(configuration)} rows.") + raise ValueError( + "Configuration dataframe must contain exactly 1 row. " + f"Found {len(configuration)} rows." 
+ ) target_values_dict = configuration.iloc[0].to_dict() - target_configuration = ConfigSpace.Configuration(self.target_parameter_space, values=target_values_dict) + target_configuration = ConfigSpace.Configuration( + self.target_parameter_space, values=target_values_dict + ) orig_values_dict = self._transform(target_values_dict) orig_configuration = normalize_config(self.orig_parameter_space, orig_values_dict) @@ -138,9 +155,13 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: # Add to inverse dictionary -- needed for registering the performance later self._suggested_configs[orig_configuration] = target_configuration - return pd.DataFrame([list(orig_configuration.values())], columns=list(orig_configuration.keys())) + return pd.DataFrame( + [list(orig_configuration.values())], columns=list(orig_configuration.keys()) + ) - def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_param: Optional[int]) -> None: + def _construct_low_dim_space( + self, num_low_dims: int, max_unique_values_per_param: Optional[int] + ) -> None: """Constructs the low-dimensional parameter (potentially discretized) search space. Parameters @@ -156,7 +177,7 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para q_scaler = None if max_unique_values_per_param is None: hyperparameters = [ - ConfigSpace.UniformFloatHyperparameter(name=f'dim_{idx}', lower=-1, upper=1) + ConfigSpace.UniformFloatHyperparameter(name=f"dim_{idx}", lower=-1, upper=1) for idx in range(num_low_dims) ] else: @@ -164,7 +185,9 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters. # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection. 
hyperparameters = [ - ConfigSpace.UniformIntegerHyperparameter(name=f'dim_{idx}', lower=1, upper=max_unique_values_per_param) + ConfigSpace.UniformIntegerHyperparameter( + name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param + ) for idx in range(num_low_dims) ] @@ -178,7 +201,9 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para # Construct low-dimensional parameter search space config_space = ConfigSpace.ConfigurationSpace(name=self.orig_parameter_space.name) - config_space.random = self._random_state # use same random state as in original parameter space + config_space.random = ( + self._random_state + ) # use same random state as in original parameter space config_space.add_hyperparameters(hyperparameters) self._target_config_space = config_space @@ -216,10 +241,10 @@ def _transform(self, configuration: dict) -> dict: # Clip value to force it to fall in [0, 1] # NOTE: HeSBO projection ensures that theoretically but due to # floating point ops nuances this is not always guaranteed - value = max(0., min(1., norm_value)) # pylint: disable=redefined-loop-name + value = max(0.0, min(1.0, norm_value)) # pylint: disable=redefined-loop-name if isinstance(param, ConfigSpace.CategoricalHyperparameter): - index = int(value * len(param.choices)) # truncate integer part + index = int(value * len(param.choices)) # truncate integer part index = max(0, min(len(param.choices) - 1, index)) # NOTE: potential rounding here would be unfair to first & last values orig_value = param.choices[index] @@ -227,16 +252,20 @@ def _transform(self, configuration: dict) -> dict: if param.name in self._special_param_values_dict: value = self._special_param_value_scaler(param, value) - orig_value = param._transform(value) # pylint: disable=protected-access + orig_value = param._transform(value) # pylint: disable=protected-access orig_value = max(param.lower, min(param.upper, orig_value)) else: - raise NotImplementedError("Only Categorical, Integer, and Float hyperparameters are currently supported.") + raise NotImplementedError( + "Only Categorical, Integer, and Float hyperparameters are currently supported." + ) original_config[param.name] = orig_value return original_config - def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float) -> float: + def _special_param_value_scaler( + self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float + ) -> float: """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. Parameters @@ -255,7 +284,7 @@ def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperpara special_values_list = self._special_param_values_dict[param.name] # Check if input value corresponds to some special value - perc_sum = 0. 
+ perc_sum = 0.0 ret: float for special_value, biasing_perc in special_values_list: perc_sum += biasing_perc @@ -264,8 +293,9 @@ def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperpara return ret # Scale input value uniformly to non-special values - ret = param._inverse_transform( # pylint: disable=protected-access - param._transform_scalar((input_value - perc_sum) / (1 - perc_sum))) # pylint: disable=protected-access + ret = param._inverse_transform( # pylint: disable=protected-access + param._transform_scalar((input_value - perc_sum) / (1 - perc_sum)) + ) # pylint: disable=protected-access return ret # pylint: disable=too-complex,too-many-branches @@ -294,8 +324,10 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non hyperparameter = self.orig_parameter_space[param] if not isinstance(hyperparameter, ConfigSpace.UniformIntegerHyperparameter): - raise NotImplementedError(error_prefix + f"Parameter '{param}' is not supported. " - "Only Integer Hyperparameters are currently supported.") + raise NotImplementedError( + error_prefix + f"Parameter '{param}' is not supported. " + "Only Integer Hyperparameters are currently supported." + ) if isinstance(value, int): # User specifies a single special value -- default biasing percentage is used @@ -306,34 +338,57 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non elif isinstance(value, list) and value: if all(isinstance(t, int) for t in value): # User specifies list of special values - tuple_list = [(v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value] - elif all(isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value): + tuple_list = [ + (v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value + ] + elif all( + isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value + ): # User specifies list of tuples; each tuple defines the special value and the biasing percentage tuple_list = value else: - raise ValueError(error_prefix + f"Invalid format in value list for parameter '{param}'. " - f"Special value list should contain either integers, or (special value, biasing %) tuples.") + raise ValueError( + error_prefix + f"Invalid format in value list for parameter '{param}'. " + f"Special value list should contain either integers, or (special value, biasing %) tuples." + ) else: - raise ValueError(error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " - "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples.") + raise ValueError( + error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " + "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples." + ) # Are user-specified special values valid? if not all(hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list): - raise ValueError(error_prefix + f"One (or more) special values are outside of parameter '{param}' value domain.") + raise ValueError( + error_prefix + + f"One (or more) special values are outside of parameter '{param}' value domain." + ) # Are user-provided special values unique? if len(set(v for v, _ in tuple_list)) != len(tuple_list): - raise ValueError(error_prefix + f"One (or more) special values are defined more than once for parameter '{param}'.") + raise ValueError( + error_prefix + + f"One (or more) special values are defined more than once for parameter '{param}'." + ) # Are biasing percentages valid? 
if not all(0 < perc < 1 for _, perc in tuple_list): - raise ValueError(error_prefix + f"One (or more) biasing percentages for parameter '{param}' are invalid: " - "i.e., fall outside (0, 1) range.") + raise ValueError( + error_prefix + + f"One (or more) biasing percentages for parameter '{param}' are invalid: " + "i.e., fall outside (0, 1) range." + ) total_percentage = sum(perc for _, perc in tuple_list) - if total_percentage >= 1.: - raise ValueError(error_prefix + f"Total special values percentage for parameter '{param}' surpass 100%.") + if total_percentage >= 1.0: + raise ValueError( + error_prefix + + f"Total special values percentage for parameter '{param}' surpass 100%." + ) # ... and reasonable? if total_percentage >= 0.5: - warn(f"Total special values percentage for parameter '{param}' exceeds 50%.", UserWarning) + warn( + f"Total special values percentage for parameter '{param}' exceeds 50%.", + UserWarning, + ) sanitized_dict[param] = tuple_list @@ -355,9 +410,12 @@ def _try_generate_approx_inverse_mapping(self) -> None: pinv, ) - warn("Trying to register a configuration that was not previously suggested by the optimizer. " + - "This inverse configuration transformation is typically not supported. " + - "However, we will try to register this configuration using an *experimental* method.", UserWarning) + warn( + "Trying to register a configuration that was not previously suggested by the optimizer. " + + "This inverse configuration transformation is typically not supported. " + + "However, we will try to register this configuration using an *experimental* method.", + UserWarning, + ) orig_space_num_dims = len(list(self.orig_parameter_space.values())) target_space_num_dims = len(list(self.target_parameter_space.values())) @@ -371,5 +429,7 @@ def _try_generate_approx_inverse_mapping(self) -> None: try: self._pinv_matrix = pinv(proj_matrix) except LinAlgError as err: - raise RuntimeError(f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}") from err + raise RuntimeError( + f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}" + ) from err assert self._pinv_matrix.shape == (target_space_num_dims, orig_space_num_dims) diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index d6918f9891..1b9e61ad91 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -27,7 +27,9 @@ FlamlSpace: TypeAlias = Dict[str, flaml.tune.sample.Domain] -def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> Dict[str, FlamlDomain]: +def configspace_to_flaml_space( + config_space: ConfigSpace.ConfigurationSpace, +) -> Dict[str, FlamlDomain]: """Converts a ConfigSpace.ConfigurationSpace to dict. 
Parameters @@ -50,13 +52,19 @@ def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain: if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter): # FIXME: upper isn't included in the range - return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper) + return flaml_numeric_type[(type(parameter), parameter.log)]( + parameter.lower, parameter.upper + ) elif isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter): - return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper + 1) + return flaml_numeric_type[(type(parameter), parameter.log)]( + parameter.lower, parameter.upper + 1 + ) elif isinstance(parameter, ConfigSpace.CategoricalHyperparameter): if len(np.unique(parameter.probabilities)) > 1: - raise ValueError("FLAML doesn't support categorical parameters with non-uniform probabilities.") - return flaml.tune.choice(parameter.choices) # TODO: set order? + raise ValueError( + "FLAML doesn't support categorical parameters with non-uniform probabilities." + ) + return flaml.tune.choice(parameter.choices) # TODO: set order? raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.") return {param.name: _one_parameter_convert(param) for param in config_space.values()} diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index a8ad146205..81d5151b20 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -21,7 +21,7 @@ from typing_extensions import TypeAlias -T = TypeVar('T') +T = TypeVar("T") def get_all_submodules(pkg: TypeAlias) -> List[str]: @@ -30,7 +30,9 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]: Useful for dynamically enumerating subclasses. """ submodules = [] - for _, submodule_name, _ in walk_packages(pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None): + for _, submodule_name, _ in walk_packages( + pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None + ): submodules.append(submodule_name) return submodules @@ -41,7 +43,8 @@ def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]: Useful for dynamically enumerating expected test cases. 
""" return set(cls.__subclasses__()).union( - s for c in cls.__subclasses__() for s in _get_all_subclasses(c)) + s for c in cls.__subclasses__() for s in _get_all_subclasses(c) + ) def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]: @@ -57,5 +60,11 @@ def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> pkg = import_module(pkg_name) submodules = get_all_submodules(pkg) assert submodules - return sorted([subclass for subclass in _get_all_subclasses(cls) if not getattr(subclass, "__abstractmethods__", None)], - key=lambda c: (c.__module__, c.__name__)) + return sorted( + [ + subclass + for subclass in _get_all_subclasses(cls) + if not getattr(subclass, "__abstractmethods__", None) + ], + key=lambda c: (c.__module__, c.__name__), + ) diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index c7a94dfcc4..775afa2455 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -17,24 +17,27 @@ @pytest.mark.filterwarnings("error:Not Implemented") -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_context_not_implemented_warning(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_context_not_implemented_warning( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """ Make sure we raise warnings for the functionality that has not been implemented yet. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) suggestion, _metadata = optimizer.suggest() - scores = pd.DataFrame({'score': [1]}) + scores = pd.DataFrame({"score": [1]}) context = pd.DataFrame([["something"]]) with pytest.raises(UserWarning): diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py index 39231bec5c..504c91eac7 100644 --- a/mlos_core/mlos_core/tests/optimizers/conftest.py +++ b/mlos_core/mlos_core/tests/optimizers/conftest.py @@ -18,9 +18,9 @@ def configuration_space() -> CS.ConfigurationSpace: # Start defining a ConfigurationSpace for the Optimizer to search. space = CS.ConfigurationSpace(seed=1234) # Add a continuous input dimension between 0 and 1. - space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) + space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) # Add a categorical hyperparameter with 3 possible values. - space.add_hyperparameter(CS.CategoricalHyperparameter(name='y', choices=["a", "b", "c"])) + space.add_hyperparameter(CS.CategoricalHyperparameter(name="y", choices=["a", "b", "c"])) # Add a discrete input dimension between 0 and 10. 
- space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='z', lower=0, upper=10)) + space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10)) return space diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index 725d92fbe9..7fe793a824 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -23,11 +23,13 @@ def data_frame() -> pd.DataFrame: Toy data frame corresponding to the `configuration_space` hyperparameters. The columns are deliberately *not* in alphabetic order. """ - return pd.DataFrame({ - 'y': ['a', 'b', 'c'], - 'x': [0.1, 0.2, 0.3], - 'z': [1, 5, 8], - }) + return pd.DataFrame( + { + "y": ["a", "b", "c"], + "x": [0.1, 0.2, 0.3], + "z": [1, 5, 8], + } + ) @pytest.fixture @@ -36,11 +38,13 @@ def one_hot_data_frame() -> npt.NDArray: One-hot encoding of the `data_frame` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array([ - [0.1, 1.0, 0.0, 0.0, 1.0], - [0.2, 0.0, 1.0, 0.0, 5.0], - [0.3, 0.0, 0.0, 1.0, 8.0], - ]) + return np.array( + [ + [0.1, 1.0, 0.0, 0.0, 1.0], + [0.2, 0.0, 1.0, 0.0, 5.0], + [0.3, 0.0, 0.0, 1.0, 8.0], + ] + ) @pytest.fixture @@ -49,11 +53,13 @@ def series() -> pd.Series: Toy series corresponding to the `configuration_space` hyperparameters. The columns are deliberately *not* in alphabetic order. """ - return pd.Series({ - 'y': 'b', - 'x': 0.4, - 'z': 3, - }) + return pd.Series( + { + "y": "b", + "x": 0.4, + "z": 3, + } + ) @pytest.fixture @@ -62,9 +68,11 @@ def one_hot_series() -> npt.NDArray: One-hot encoding of the `series` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array([ - [0.4, 0.0, 1.0, 0.0, 3], - ]) + return np.array( + [ + [0.4, 0.0, 1.0, 0.0, 3], + ] + ) @pytest.fixture @@ -74,39 +82,40 @@ def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: """ return SmacOptimizer( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], ) -def test_to_1hot_data_frame(optimizer: BaseOptimizer, - data_frame: pd.DataFrame, - one_hot_data_frame: npt.NDArray) -> None: +def test_to_1hot_data_frame( + optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray +) -> None: """ Toy problem to test one-hot encoding of dataframe. """ assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame) -def test_to_1hot_series(optimizer: BaseOptimizer, - series: pd.Series, one_hot_series: npt.NDArray) -> None: +def test_to_1hot_series( + optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray +) -> None: """ Toy problem to test one-hot encoding of series. """ assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series) -def test_from_1hot_data_frame(optimizer: BaseOptimizer, - data_frame: pd.DataFrame, - one_hot_data_frame: npt.NDArray) -> None: +def test_from_1hot_data_frame( + optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray +) -> None: """ Toy problem to test one-hot decoding of dataframe. 
""" assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict() -def test_from_1hot_series(optimizer: BaseOptimizer, - series: pd.Series, - one_hot_series: npt.NDArray) -> None: +def test_from_1hot_series( + optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray +) -> None: """ Toy problem to test one-hot decoding of series. """ @@ -135,8 +144,9 @@ def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> No assert (series_round_trip.z == series.z).all() -def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer, - one_hot_data_frame: npt.NDArray) -> None: +def test_round_trip_reverse_data_frame( + optimizer: BaseOptimizer, one_hot_data_frame: npt.NDArray +) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ @@ -144,8 +154,7 @@ def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer, assert round_trip == pytest.approx(one_hot_data_frame) -def test_round_trip_reverse_series(optimizer: BaseOptimizer, - one_hot_series: npt.NDArray) -> None: +def test_round_trip_reverse_series(optimizer: BaseOptimizer, one_hot_series: npt.NDArray) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index 0b9d624a7a..870943c346 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -20,10 +20,15 @@ _LOG = logging.getLogger(__name__) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_multi_target_opt_wrong_weights( + optimizer_class: Type[BaseOptimizer], kwargs: dict +) -> None: """ Make sure that the optimizer raises an error if the number of objective weights does not match the number of optimization targets. @@ -31,23 +36,29 @@ def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kw with pytest.raises(ValueError): optimizer_class( parameter_space=CS.ConfigurationSpace(seed=SEED), - optimization_targets=['main_score', 'other_score'], + optimization_targets=["main_score", "other_score"], objective_weights=[1], - **kwargs + **kwargs, ) -@pytest.mark.parametrize(('objective_weights'), [ - [2, 1], - [0.5, 0.5], - None, -]) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_multi_target_opt(objective_weights: Optional[List[float]], - optimizer_class: Type[BaseOptimizer], - kwargs: dict) -> None: +@pytest.mark.parametrize( + ("objective_weights"), + [ + [2, 1], + [0.5, 0.5], + None, + ], +) +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_multi_target_opt( + objective_weights: Optional[List[float]], optimizer_class: Type[BaseOptimizer], kwargs: dict +) -> None: """ Toy multi-target optimization problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
@@ -56,21 +67,21 @@ def test_multi_target_opt(objective_weights: Optional[List[float]], def objective(point: pd.DataFrame) -> pd.DataFrame: # mix of hyperparameters, optimal is to select the highest possible - return pd.DataFrame({ - "main_score": point.x + point.y, - "other_score": point.x ** 2 + point.y ** 2, - }) + return pd.DataFrame( + { + "main_score": point.x + point.y, + "other_score": point.x**2 + point.y**2, + } + ) input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0)) optimizer = optimizer_class( parameter_space=input_space, - optimization_targets=['main_score', 'other_score'], + optimization_targets=["main_score", "other_score"], objective_weights=objective_weights, **kwargs, ) @@ -85,27 +96,28 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {'x', 'y'} + assert set(suggestion.columns) == {"x", "y"} # Check suggestion values are the expected dtype assert isinstance(suggestion.x.iloc[0], np.integer) assert isinstance(suggestion.y.iloc[0], np.floating) # Check that suggestion is in the space test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) + optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() + ) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
observation = objective(suggestion) assert isinstance(observation, pd.DataFrame) - assert set(observation.columns) == {'main_score', 'other_score'} + assert set(observation.columns) == {"main_score", "other_score"} optimizer.register(configs=suggestion, scores=observation) (best_config, best_score, best_context) = optimizer.get_best_observations() assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y'} - assert set(best_score.columns) == {'main_score', 'other_score'} + assert set(best_config.columns) == {"x", "y"} + assert set(best_score.columns) == {"main_score", "other_score"} assert best_config.shape == (1, 2) assert best_score.shape == (1, 2) @@ -113,7 +125,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y'} - assert set(all_scores.columns) == {'main_score', 'other_score'} + assert set(all_configs.columns) == {"x", "y"} + assert set(all_scores.columns) == {"main_score", "other_score"} assert all_configs.shape == (max_iterations, 2) assert all_scores.shape == (max_iterations, 2) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index 5fd28ca1ed..d5d00d0692 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -32,20 +32,24 @@ _LOG.setLevel(logging.DEBUG) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_create_optimizer_and_suggest( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """ Test that we can create an optimizer and get a suggestion from it. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) assert optimizer is not None @@ -62,11 +66,17 @@ def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace optimizer.register_pending(configs=suggestion, metadata=metadata) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_basic_interface_toy_problem( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """ Toy problem to test the optimizers. """ @@ -77,17 +87,15 @@ def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, if optimizer_class == OptimizerType.SMAC.value: # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100. 
# To avoid having to train more than 25 model iterations, we set a lower number of max iterations. - kwargs['max_trials'] = max_iterations * 2 + kwargs["max_trials"] = max_iterations * 2 def objective(x: pd.Series) -> pd.DataFrame: - return pd.DataFrame({"score": (6 * x - 2)**2 * np.sin(12 * x - 4)}) + return pd.DataFrame({"score": (6 * x - 2) ** 2 * np.sin(12 * x - 4)}) # Emukit doesn't allow specifying a random state, so we set the global seed. np.random.seed(SEED) optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) with pytest.raises(ValueError, match="No observations"): @@ -100,12 +108,12 @@ def objective(x: pd.Series) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {'x', 'y', 'z'} + assert set(suggestion.columns) == {"x", "y", "z"} # check that suggestion is in the space configuration = CS.Configuration(optimizer.parameter_space, suggestion.iloc[0].to_dict()) # Raises an error if outside of configuration space configuration.is_valid_configuration() - observation = objective(suggestion['x']) + observation = objective(suggestion["x"]) assert isinstance(observation, pd.DataFrame) optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -113,8 +121,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y', 'z'} - assert set(best_score.columns) == {'score'} + assert set(best_config.columns) == {"x", "y", "z"} + assert set(best_score.columns) == {"score"} assert best_config.shape == (1, 3) assert best_score.shape == (1, 1) assert best_score.score.iloc[0] < -5 @@ -123,8 +131,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y', 'z'} - assert set(all_scores.columns) == {'score'} + assert set(all_configs.columns) == {"x", "y", "z"} + assert set(all_scores.columns) == {"score"} assert all_configs.shape == (20, 3) assert all_scores.shape == (20, 1) @@ -137,27 +145,36 @@ def objective(x: pd.Series) -> pd.DataFrame: assert pred_all.shape == (20,) -@pytest.mark.parametrize(('optimizer_type'), [ - # Enumerate all supported Optimizers - # *[member for member in OptimizerType], - *list(OptimizerType), -]) +@pytest.mark.parametrize( + ("optimizer_type"), + [ + # Enumerate all supported Optimizers + # *[member for member in OptimizerType], + *list(OptimizerType), + ], +) def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: """ Test that all optimizer types are listed in the ConcreteOptimizer constraints. 
""" - assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member - - -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument -]) -def test_create_optimizer_with_factory_method(configuration_space: CS.ConfigurationSpace, - optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: + assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member + + +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ], +) +def test_create_optimizer_with_factory_method( + configuration_space: CS.ConfigurationSpace, + optimizer_type: Optional[OptimizerType], + kwargs: Optional[dict], +) -> None: """ Test that we can create an optimizer via a factory. """ @@ -166,13 +183,13 @@ def test_create_optimizer_with_factory_method(configuration_space: CS.Configurat if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -188,16 +205,22 @@ def test_create_optimizer_with_factory_method(configuration_space: CS.Configurat assert myrepr.startswith(optimizer_type.value.__name__) -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - (OptimizerType.SMAC, { - # Test with default config. - 'use_default_config': True, - # 'n_random_init': 10, - }), -]) +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ( + OptimizerType.SMAC, + { + # Test with default config. + "use_default_config": True, + # 'n_random_init': 10, + }, + ), + ], +) def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optional[dict]) -> None: """ Toy problem to test the optimizers with llamatune space adapter. 
@@ -215,8 +238,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=1234) # Add two continuous inputs - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=3)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=3)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=3)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=3)) # Initialize an optimizer that uses LlamaTune space adapter space_adapter_kwargs = { @@ -239,7 +262,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: llamatune_optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=llamatune_optimizer_kwargs, space_adapter_type=SpaceAdapterType.LLAMATUNE, @@ -248,7 +271,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Initialize an optimizer that uses the original space optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=optimizer_kwargs, ) @@ -257,7 +280,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert optimizer.optimizer_parameter_space != llamatune_optimizer.optimizer_parameter_space llamatune_n_random_init = 0 - opt_n_random_init = int(kwargs.get('n_random_init', 0)) + opt_n_random_init = int(kwargs.get("n_random_init", 0)) if optimizer_type == OptimizerType.SMAC: assert isinstance(optimizer, SmacOptimizer) assert isinstance(llamatune_optimizer, SmacOptimizer) @@ -278,8 +301,10 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # loop for llamatune-optimizer suggestion, metadata = llamatune_optimizer.suggest() - _x, _y = suggestion['x'].iloc[0], suggestion['y'].iloc[0] - assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3., rel=1e-3) # optimizer explores 1-dimensional space + _x, _y = suggestion["x"].iloc[0], suggestion["y"].iloc[0] + assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx( + 3.0, rel=1e-3 + ) # optimizer explores 1-dimensional space observation = objective(suggestion) llamatune_optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -287,28 +312,32 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: best_observation = optimizer.get_best_observations() llamatune_best_observation = llamatune_optimizer.get_best_observations() - for (best_config, best_score, best_context) in (best_observation, llamatune_best_observation): + for best_config, best_score, best_context in (best_observation, llamatune_best_observation): assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y'} - assert set(best_score.columns) == {'score'} + assert set(best_config.columns) == {"x", "y"} + assert set(best_score.columns) == {"score"} (best_config, best_score, _context) = best_observation (llamatune_best_config, llamatune_best_score, _context) = llamatune_best_observation # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that - assert best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] or \ - best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] + assert ( + best_score.score.iloc[0] > 
llamatune_best_score.score.iloc[0] + or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] + ) # Retrieve and check all observations - for (all_configs, all_scores, all_contexts) in ( - optimizer.get_observations(), llamatune_optimizer.get_observations()): + for all_configs, all_scores, all_contexts in ( + optimizer.get_observations(), + llamatune_optimizer.get_observations(), + ): assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y'} - assert set(all_scores.columns) == {'score'} + assert set(all_configs.columns) == {"x", "y"} + assert set(all_scores.columns) == {"score"} assert len(all_configs) == num_iters assert len(all_scores) == num_iters @@ -320,12 +349,13 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses(BaseOptimizer, # type: ignore[type-abstract] - pkg_name='mlos_core') +optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses( + BaseOptimizer, pkg_name="mlos_core" # type: ignore[type-abstract] +) assert optimizer_subclasses -@pytest.mark.parametrize(('optimizer_class'), optimizer_subclasses) +@pytest.mark.parametrize(("optimizer_class"), optimizer_subclasses) def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: """ Test that all optimizer classes are listed in the OptimizerType enum. @@ -334,14 +364,19 @@ def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: assert optimizer_class in optimizer_type_classes -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument -]) -def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ], +) +def test_mixed_numerics_type_input_space_types( + optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] +) -> None: """ Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
""" @@ -355,19 +390,19 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0)) if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -381,12 +416,14 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: for _ in range(max_iterations): suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) - assert (suggestion.columns == ['x', 'y']).all() + assert (suggestion.columns == ["x", "y"]).all() # Check suggestion values are the expected dtype - assert isinstance(suggestion['x'].iloc[0], np.integer) - assert isinstance(suggestion['y'].iloc[0], np.floating) + assert isinstance(suggestion["x"].iloc[0], np.integer) + assert isinstance(suggestion["y"].iloc[0], np.floating) # Check that suggestion is in the space - test_configuration = CS.Configuration(optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) + test_configuration = CS.Configuration( + optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() + ) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py index 37b8aa3a69..13a28d242d 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py @@ -20,22 +20,33 @@ def test_identity_adapter() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='float_1', lower=0, upper=100)) + CS.UniformFloatHyperparameter(name="float_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) + CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) + ) adapter = IdentityAdapter(orig_parameter_space=input_space) num_configs = 10 - for sampled_config in input_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # (false positive) - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + for sampled_config in input_space.sample_configuration( + size=num_configs + ): # pylint: disable=not-an-iterable # (false positive) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) assert target_config_df.equals(sampled_config_df) - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == sampled_config orig_config_df = adapter.transform(target_config_df) assert orig_config_df.equals(sampled_config_df) - orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() + ) assert orig_config == sampled_config diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index 84dcd4e5c0..cd1b250ab7 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -30,34 +30,46 @@ def construct_parameter_space( for idx in range(n_continuous_params): input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name=f'cont_{idx}', lower=0, upper=64)) + CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64) + ) for idx in range(n_integer_params): input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name=f'int_{idx}', lower=-1, upper=256)) + CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=-1, upper=256) + ) for idx in range(n_categorical_params): input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name=f'str_{idx}', choices=[f'option_{idx}' for idx in range(5)])) + CS.CategoricalHyperparameter( + name=f"str_{idx}", choices=[f"option_{idx}" for idx in range(5)] + ) + ) return input_space -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( 
- {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) + ] + ), +) +def test_num_low_dims( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's low-to-high space projection method. """ @@ -66,8 +78,7 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N # Number of target parameter space dimensions should be fewer than those of the original space with pytest.raises(ValueError): LlamaTuneAdapter( - orig_parameter_space=input_space, - num_low_dims=len(list(input_space.keys())) + orig_parameter_space=input_space, num_low_dims=len(list(input_space.keys())) ) # Enable only low-dimensional space projections @@ -75,13 +86,15 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N orig_parameter_space=input_space, num_low_dims=num_target_space_dims, special_param_values=None, - max_unique_values_per_param=None + max_unique_values_per_param=None, ) sampled_configs = adapter.target_parameter_space.sample_configuration(size=100) for sampled_config in sampled_configs: # pylint: disable=not-an-iterable # (false positive) # Transform low-dim config to high-dim point/config - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) orig_config_df = adapter.transform(sampled_config_df) # High-dim (i.e., original) config should be valid @@ -92,18 +105,28 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == sampled_config # Try inverse projection (i.e., high-to-low) for previously unseen configs unseen_sampled_configs = adapter.target_parameter_space.sample_configuration(size=25) - for unseen_sampled_config in unseen_sampled_configs: # pylint: 
disable=not-an-iterable # (false positive) - if unseen_sampled_config in sampled_configs: # pylint: disable=unsupported-membership-test # (false positive) + for ( + unseen_sampled_config + ) in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) + if ( + unseen_sampled_config in sampled_configs + ): # pylint: disable=unsupported-membership-test # (false positive) continue - unseen_sampled_config_df = pd.DataFrame([unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys())) + unseen_sampled_config_df = pd.DataFrame( + [unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys()) + ) with pytest.raises(ValueError): - _ = adapter.inverse_transform(unseen_sampled_config_df) # pylint: disable=redefined-variable-type + _ = adapter.inverse_transform( + unseen_sampled_config_df + ) # pylint: disable=redefined-variable-type def test_special_parameter_values_validation() -> None: @@ -112,15 +135,14 @@ def test_special_parameter_values_validation() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str', choices=[f'choice_{idx}' for idx in range(5)])) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont', lower=-1, upper=100)) - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int', lower=0, upper=100)) + CS.CategoricalHyperparameter(name="str", choices=[f"choice_{idx}" for idx in range(5)]) + ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont", lower=-1, upper=100)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="int", lower=0, upper=100)) # Only UniformIntegerHyperparameters are currently supported with pytest.raises(NotImplementedError): - special_param_values_dict_1 = {'str': 'choice_1'} + special_param_values_dict_1 = {"str": "choice_1"} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -129,7 +151,7 @@ def test_special_parameter_values_validation() -> None: ) with pytest.raises(NotImplementedError): - special_param_values_dict_2 = {'cont': -1} + special_param_values_dict_2 = {"cont": -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -138,8 +160,8 @@ def test_special_parameter_values_validation() -> None: ) # Special value should belong to parameter value domain - with pytest.raises(ValueError, match='value domain'): - special_param_values_dict = {'int': -1} + with pytest.raises(ValueError, match="value domain"): + special_param_values_dict = {"int": -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -149,15 +171,15 @@ def test_special_parameter_values_validation() -> None: # Invalid dicts; ValueError should be thrown invalid_special_param_values_dicts: List[Dict[str, Any]] = [ - {'int-Q': 0}, # parameter does not exist - {'int': {0: 0.2}}, # invalid definition - {'int': 0.2}, # invalid parameter value - {'int': (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) - {'int': [0, 0]}, # duplicate special values - {'int': []}, # empty list - {'int': [{0: 0.2}]}, - {'int': [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct - {'int': [(0, 0.1), (0, 0.2)]}, # duplicate special values + {"int-Q": 0}, # parameter does not exist + {"int": {0: 0.2}}, # invalid definition + {"int": 0.2}, # invalid parameter value + {"int": (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) + {"int": [0, 0]}, # duplicate special values + {"int": []}, # empty 
list + {"int": [{0: 0.2}]}, + {"int": [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct + {"int": [(0, 0.1), (0, 0.2)]}, # duplicate special values ] for spv_dict in invalid_special_param_values_dicts: with pytest.raises(ValueError): @@ -170,13 +192,13 @@ def test_special_parameter_values_validation() -> None: # Biasing percentage of special value(s) are invalid invalid_special_param_values_dicts = [ - {'int': (0, 1.1)}, # >1 probability - {'int': (0, 0)}, # Zero probability - {'int': (0, -0.1)}, # Negative probability - {'int': (0, 20)}, # 2,000% instead of 20% - {'int': [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% - {'int': [(0, 0.4), (1, 0.7)]}, # combined probability >100% - {'int': [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. + {"int": (0, 1.1)}, # >1 probability + {"int": (0, 0)}, # Zero probability + {"int": (0, -0.1)}, # Negative probability + {"int": (0, 20)}, # 2,000% instead of 20% + {"int": [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% + {"int": [(0, 0.4), (1, 0.7)]}, # combined probability >100% + {"int": [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. ] for spv_dict in invalid_special_param_values_dicts: @@ -192,21 +214,27 @@ def test_special_parameter_values_validation() -> None: def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[CS.Configuration]: for sampled_config in adapter.target_parameter_space.sample_configuration(size=num_configs): # Transform low-dim config to high-dim config - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) orig_config_df = adapter.transform(sampled_config_df) - orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() + ) yield orig_config -def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex +def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex """ Tests LlamaTune's special parameter values biasing methodology """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=100) + ) num_configs = 400 bias_percentage = LlamaTuneAdapter.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE @@ -214,10 +242,10 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co # Single parameter; single special value special_param_value_dicts: List[Dict[str, Any]] = [ - {'int_1': 0}, - {'int_1': (0, bias_percentage)}, - {'int_1': [0]}, - {'int_1': [(0, bias_percentage)]} + {"int_1": 0}, + {"int_1": (0, bias_percentage)}, + {"int_1": [0]}, + {"int_1": [(0, bias_percentage)]}, ] for spv_dict in special_param_value_dicts: @@ -229,13 +257,14 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co ) special_value_occurrences = sum( - 1 for config in gen_random_configs(adapter, num_configs) if config['int_1'] == 0) + 1 for config in gen_random_configs(adapter, num_configs) if config["int_1"] == 0 + ) assert (1 - eps) * 
int(num_configs * bias_percentage) <= special_value_occurrences # Single parameter; multiple special values special_param_value_dicts = [ - {'int_1': [0, 1]}, - {'int_1': [(0, bias_percentage), (1, bias_percentage)]} + {"int_1": [0, 1]}, + {"int_1": [(0, bias_percentage), (1, bias_percentage)]}, ] for spv_dict in special_param_value_dicts: @@ -248,9 +277,9 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co special_values_occurrences = {0: 0, 1: 0} for config in gen_random_configs(adapter, num_configs): - if config['int_1'] == 0: + if config["int_1"] == 0: special_values_occurrences[0] += 1 - elif config['int_1'] == 1: + elif config["int_1"] == 1: special_values_occurrences[1] += 1 assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[0] @@ -258,8 +287,8 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co # Multiple parameters; multiple special values; different biasing percentage spv_dict = { - 'int_1': [(0, bias_percentage), (1, bias_percentage / 2)], - 'int_2': [(2, bias_percentage / 2), (100, bias_percentage * 1.5)] + "int_1": [(0, bias_percentage), (1, bias_percentage / 2)], + "int_2": [(2, bias_percentage / 2), (100, bias_percentage * 1.5)], } adapter = LlamaTuneAdapter( orig_parameter_space=input_space, @@ -269,24 +298,30 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co ) special_values_instances: Dict[str, Dict[int, int]] = { - 'int_1': {0: 0, 1: 0}, - 'int_2': {2: 0, 100: 0}, + "int_1": {0: 0, 1: 0}, + "int_2": {2: 0, 100: 0}, } for config in gen_random_configs(adapter, num_configs): - if config['int_1'] == 0: - special_values_instances['int_1'][0] += 1 - elif config['int_1'] == 1: - special_values_instances['int_1'][1] += 1 - - if config['int_2'] == 2: - special_values_instances['int_2'][2] += 1 - elif config['int_2'] == 100: - special_values_instances['int_2'][100] += 1 - - assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances['int_1'][0] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_1'][1] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_2'][2] - assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances['int_2'][100] + if config["int_1"] == 0: + special_values_instances["int_1"][0] += 1 + elif config["int_1"] == 1: + special_values_instances["int_1"][1] += 1 + + if config["int_2"] == 2: + special_values_instances["int_2"][2] += 1 + elif config["int_2"] == 100: + special_values_instances["int_2"][100] += 1 + + assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances["int_1"][0] + assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_1"][ + 1 + ] + assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_2"][ + 2 + ] + assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances[ + "int_2" + ][100] def test_max_unique_values_per_param() -> None: @@ -295,18 +330,22 @@ def test_max_unique_values_per_param() -> None: """ # Define config space with a mix of different parameter types input_space = CS.ConfigurationSpace(seed=1234) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5)) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont_1', lower=0, upper=5)) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont_2', 
lower=1, upper=100)) + CS.UniformFloatHyperparameter(name="cont_2", lower=1, upper=100) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=1, upper=10)) + CS.UniformIntegerHyperparameter(name="int_1", lower=1, upper=10) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=2048)) + CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=2048) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) + CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_2', choices=[f'choice_{idx}' for idx in range(10)])) + CS.CategoricalHyperparameter(name="str_2", choices=[f"choice_{idx}" for idx in range(10)]) + ) # Restrict the number of unique parameter values num_configs = 200 @@ -329,23 +368,30 @@ def test_max_unique_values_per_param() -> None: assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) + ] + ), +) +def test_approx_inverse_mapping( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's approximate high-to-low space projection method, using pseudo-inverse. 
""" @@ -360,9 +406,11 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: use_approximate_reverse_mapping=False, ) - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.raises(ValueError): - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) _ = adapter.inverse_transform(sampled_config_df) # Enable low-dimensional space projection *and* reverse mapping @@ -375,41 +423,63 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: ) # Warning should be printed the first time - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.warns(UserWarning): - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) adapter.target_parameter_space.check_configuration(target_config) # Test inverse transform with 100 random configs for _ in range(100): - sampled_config = input_space.sample_configuration() # size=1) - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config = input_space.sample_configuration() # size=1) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) adapter.target_parameter_space.check_configuration(target_config) -@pytest.mark.parametrize(('num_low_dims', 'special_param_values', 'max_unique_values_per_param'), ([ - (num_low_dims, special_param_values, max_unique_values_per_param) - for num_low_dims in (8, 16) - for special_param_values in ( - {'int_1': -1, 'int_2': -1, 'int_3': -1, 'int_4': [-1, 0]}, - {'int_1': (-1, 0.1), 'int_2': -1, 'int_3': (-1, 0.3), 'int_4': [(-1, 0.1), (0, 0.2)]}, - ) - for max_unique_values_per_param in (50, 250) -])) -def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int) -> None: +@pytest.mark.parametrize( + ("num_low_dims", "special_param_values", "max_unique_values_per_param"), + ( + [ + (num_low_dims, special_param_values, max_unique_values_per_param) + for num_low_dims in (8, 16) + for special_param_values in ( + {"int_1": -1, "int_2": -1, "int_3": -1, "int_4": [-1, 0]}, + { + "int_1": (-1, 0.1), + "int_2": -1, + "int_3": (-1, 0.3), + "int_4": [(-1, 0.1), (0, 0.2)], + }, + ) + for max_unique_values_per_param in (50, 250) + ] + ), +) +def test_llamatune_pipeline( + num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int +) -> None: """ Tests LlamaTune space adapter when all components are 
active. """ # pylint: disable=too-many-locals # Define config space with a mix of different parameter types - input_space = construct_parameter_space(n_continuous_params=10, n_integer_params=10, n_categorical_params=5) + input_space = construct_parameter_space( + n_continuous_params=10, n_integer_params=10, n_categorical_params=5 + ) adapter = LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=num_low_dims, @@ -419,12 +489,14 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u special_value_occurrences = { param: {special_value: 0 for special_value, _ in tuples_list} - for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access + for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access } unique_values_dict: Dict[str, Set] = {param: set() for param in input_space.keys()} num_configs = 1000 - for config in adapter.target_parameter_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable + for config in adapter.target_parameter_space.sample_configuration( + size=num_configs + ): # pylint: disable=not-an-iterable # Transform low-dim config to high-dim point/config sampled_config_df = pd.DataFrame([config.values()], columns=list(config.keys())) orig_config_df = adapter.transform(sampled_config_df) @@ -435,7 +507,9 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u # Transform high-dim config back to low-dim target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == config for param, value in orig_config.items(): @@ -449,35 +523,48 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u # Ensure that occurrences of special values do not significantly deviate from expected eps = 0.2 - for param, tuples_list in adapter._special_param_values_dict.items(): # pylint: disable=protected-access + for ( + param, + tuples_list, + ) in adapter._special_param_values_dict.items(): # pylint: disable=protected-access for value, bias_percentage in tuples_list: - assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[param][value] + assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[ + param + ][value] # Ensure that number of unique values is less than the maximum number allowed for _, unique_values in unique_values_dict.items(): assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def 
test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_space_kwargs: dict) -> None: +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) + ] + ), +) +def test_deterministic_behavior_for_same_seed( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: """ Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. """ + def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: input_space = construct_parameter_space(**param_space_kwargs, seed=seed) @@ -490,7 +577,9 @@ def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: use_approximate_reverse_mapping=False, ) - sample_configs: List[CS.Configuration] = adapter.target_parameter_space.sample_configuration(size=100) + sample_configs: List[CS.Configuration] = ( + adapter.target_parameter_space.sample_configuration(size=100) + ) return sample_configs assert generate_target_param_space_configs(42) == generate_target_param_space_configs(42) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index 5390f97c5f..6e5eab7d96 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -23,39 +23,47 @@ from mlos_core.tests import get_all_concrete_subclasses -@pytest.mark.parametrize(('space_adapter_type'), [ - # Enumerate all supported SpaceAdapters - # *[member for member in SpaceAdapterType], - *list(SpaceAdapterType), -]) +@pytest.mark.parametrize( + ("space_adapter_type"), + [ + # Enumerate all supported SpaceAdapters + # *[member for member in SpaceAdapterType], + *list(SpaceAdapterType), + ], +) def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: """ Test that all optimizer types are listed in the ConcreteOptimizer constraints. 
""" # pylint: disable=no-member - assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] + assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] -@pytest.mark.parametrize(('space_adapter_type', 'kwargs'), [ - # Default space adapter - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in SpaceAdapterType], -]) -def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("space_adapter_type", "kwargs"), + [ + # Default space adapter + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in SpaceAdapterType], + ], +) +def test_create_space_adapter_with_factory_method( + space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict] +) -> None: # Start defining a ConfigurationSpace for the Optimizer to search. input_space = CS.ConfigurationSpace(seed=1234) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=1)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=1)) # Adjust some kwargs for specific space adapters if space_adapter_type is SpaceAdapterType.LLAMATUNE: if kwargs is None: kwargs = {} - kwargs.setdefault('num_low_dims', 1) + kwargs.setdefault("num_low_dims", 1) space_adapter: BaseSpaceAdapter if space_adapter_type is None: @@ -73,21 +81,25 @@ def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[S assert space_adapter is not None assert space_adapter.orig_parameter_space is not None myrepr = repr(space_adapter) - assert myrepr.startswith(space_adapter_type.value.__name__), \ - f"Expected {space_adapter_type.value.__name__} but got {myrepr}" + assert myrepr.startswith( + space_adapter_type.value.__name__ + ), f"Expected {space_adapter_type.value.__name__} but got {myrepr}" # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = \ - get_all_concrete_subclasses(BaseSpaceAdapter, pkg_name='mlos_core') # type: ignore[type-abstract] +space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = get_all_concrete_subclasses( + BaseSpaceAdapter, pkg_name="mlos_core" +) # type: ignore[type-abstract] assert space_adapter_subclasses -@pytest.mark.parametrize(('space_adapter_class'), space_adapter_subclasses) +@pytest.mark.parametrize(("space_adapter_class"), space_adapter_subclasses) def test_space_adapter_type_defs(space_adapter_class: Type[BaseSpaceAdapter]) -> None: """ Test that all space adapter classes are listed in the SpaceAdapterType enum. 
""" - space_adapter_type_classes = {space_adapter_type.value for space_adapter_type in SpaceAdapterType} + space_adapter_type_classes = { + space_adapter_type.value for space_adapter_type in SpaceAdapterType + } assert space_adapter_class in space_adapter_type_classes diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index dee9251652..f7cde8ae88 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -41,9 +41,9 @@ def assert_is_uniform(arr: npt.NDArray) -> None: assert np.isclose(frequencies.sum(), 1) _f_chi_sq, f_p_value = scipy.stats.chisquare(frequencies) - assert np.isclose(kurtosis, -1.2, atol=.1) - assert p_value > .3 - assert f_p_value > .5 + assert np.isclose(kurtosis, -1.2, atol=0.1) + assert p_value > 0.3 + assert f_p_value > 0.5 def assert_is_log_uniform(arr: npt.NDArray, base: float = np.e) -> None: @@ -70,13 +70,14 @@ def invalid_conversion_function(*args: Any) -> NoReturn: """ A quick dummy function for the base class to make pylint happy. """ - raise NotImplementedError('subclass must override conversion_function') + raise NotImplementedError("subclass must override conversion_function") class BaseConversion(metaclass=ABCMeta): """ Base class for testing optimizer space conversions. """ + conversion_function: Callable[..., OptimizerSpace] = invalid_conversion_function @abstractmethod @@ -150,8 +151,8 @@ def test_uniform_samples(self) -> None: assert_is_uniform(uniform) # Check that we get both ends of the sampled range returned to us. - assert input_space['c'].lower in integer_uniform - assert input_space['c'].upper in integer_uniform + assert input_space["c"].lower in integer_uniform + assert input_space["c"].upper in integer_uniform # integer uniform assert_is_uniform(integer_uniform) @@ -165,13 +166,13 @@ def test_uniform_categorical(self) -> None: assert 35 < counts[1] < 65 def test_weighted_categorical(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") def test_log_int_spaces(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") def test_log_float_spaces(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") class TestFlamlConversion(BaseConversion): @@ -184,10 +185,12 @@ class TestFlamlConversion(BaseConversion): def sample(self, config_space: FlamlSpace, n_samples: int = 1) -> npt.NDArray: # type: ignore[override] assert isinstance(config_space, dict) assert isinstance(next(iter(config_space.values())), flaml.tune.sample.Domain) - ret: npt.NDArray = np.array([domain.sample(size=n_samples) for domain in config_space.values()]).T + ret: npt.NDArray = np.array( + [domain.sample(size=n_samples) for domain in config_space.values()] + ).T return ret - def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] + def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] assert isinstance(config_space, dict) ret: List[str] = list(config_space.keys()) return ret @@ -208,7 +211,9 @@ def test_dimensionality(self) -> None: def test_weighted_categorical(self) -> None: np.random.seed(42) input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1])) + input_space.add_hyperparameter( + 
CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1]) + ) with pytest.raises(ValueError, match="non-uniform"): configspace_to_flaml_space(input_space) @@ -217,7 +222,9 @@ def test_log_int_spaces(self) -> None: np.random.seed(42) # integer is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True)) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True) + ) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -235,7 +242,9 @@ def test_log_float_spaces(self) -> None: # continuous is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True) + ) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -245,6 +254,6 @@ def test_log_float_spaces(self) -> None: assert_is_log_uniform(float_log_uniform) -if __name__ == '__main__': +if __name__ == "__main__": # For attaching debugger debugging: pytest.main(["-vv", "-k", "test_log_int_spaces", __file__]) diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index df0e144535..e6fa12522e 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -28,7 +28,9 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame: return pd.DataFrame([dict(config)]) -def normalize_config(config_space: ConfigurationSpace, config: Union[Configuration, dict]) -> Configuration: +def normalize_config( + config_space: ConfigurationSpace, config: Union[Configuration, dict] +) -> Configuration: """ Convert a dictionary to a valid ConfigSpace configuration. @@ -49,8 +51,6 @@ def normalize_config(config_space: ConfigurationSpace, config: Union[Configurati """ cs_config = Configuration(config_space, values=config, allow_inactive_with_values=True) return Configuration( - config_space, values={ - key: cs_config[key] - for key in config_space.get_active_hyperparameters(cs_config) - } + config_space, + values={key: cs_config[key] for key in config_space.get_active_hyperparameters(cs_config)}, ) diff --git a/mlos_core/mlos_core/version.py b/mlos_core/mlos_core/version.py index 2362de7083..f946f94aa4 100644 --- a/mlos_core/mlos_core/version.py +++ b/mlos_core/mlos_core/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_core/setup.py b/mlos_core/setup.py index fed376d1af..4d895db315 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -21,15 +21,16 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -49,52 +50,54 @@ # we return nothing when the file is not available. 
def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass - 'flaml': ['flaml[blendsearch]'], - 'smac': ['smac>=2.0.0'], # NOTE: Major refactoring on SMAC starting from v2.0.0 + "flaml": ["flaml[blendsearch]"], + "smac": ["smac>=2.0.0"], # NOTE: Major refactoring on SMAC starting from v2.0.0 } # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", ] setup( version=VERSION, install_requires=[ - 'scikit-learn>=1.2', - 'joblib>=1.1.1', # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released - 'scipy>=1.3.2', - 'numpy>=1.24', 'numpy<2.0.0', # FIXME: https://github.com/numpy/numpy/issues/26710 - 'pandas >= 2.2.0;python_version>="3.9"', 'Bottleneck > 1.3.5;python_version>="3.9"', + "scikit-learn>=1.2", + "joblib>=1.1.1", # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released + "scipy>=1.3.2", + "numpy>=1.24", + "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710 + 'pandas >= 2.2.0;python_version>="3.9"', + 'Bottleneck > 1.3.5;python_version>="3.9"', 'pandas >= 1.0.3;python_version<"3.9"', - 'ConfigSpace>=0.7.1', + "ConfigSpace>=0.7.1", ], extras_require=extra_requires, **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_core"), diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py index 2390554e1e..1725a24ed9 100644 --- a/mlos_viz/mlos_viz/__init__.py +++ b/mlos_viz/mlos_viz/__init__.py @@ -23,7 +23,7 @@ class MlosVizMethod(Enum): """ DABL = "dabl" - AUTO = DABL # use dabl as the current default + AUTO = DABL # use dabl as the current default def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None: @@ -39,17 +39,21 @@ def ignore_plotter_warnings(plotter_method: 
MlosVizMethod = MlosVizMethod.AUTO) base.ignore_plotter_warnings() if plotter_method == MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel + mlos_viz.dabl.ignore_plotter_warnings() else: raise NotImplementedError(f"Unhandled method: {plotter_method}") -def plot(exp_data: Optional[ExperimentData] = None, *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - plotter_method: MlosVizMethod = MlosVizMethod.AUTO, - filter_warnings: bool = True, - **kwargs: Any) -> None: +def plot( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + plotter_method: MlosVizMethod = MlosVizMethod.AUTO, + filter_warnings: bool = True, + **kwargs: Any, +) -> None: """ Plots the results of the experiment. @@ -81,6 +85,7 @@ def plot(exp_data: Optional[ExperimentData] = None, *, if MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel + mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives) else: raise NotImplementedError(f"Unhandled method: {plotter_method}") diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 15358b0862..d2fc4edad7 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -20,7 +20,7 @@ from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_viz.util import expand_results_data_args -_SEABORN_VERS = version('seaborn') +_SEABORN_VERS = version("seaborn") def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: @@ -30,7 +30,7 @@ def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: Note: this only works with non-positional kwargs (e.g., those after a * arg). """ target_kwargs = {} - for kword in target.__kwdefaults__: # or {} # intentionally omitted for now + for kword in target.__kwdefaults__: # or {} # intentionally omitted for now if kword in kwargs: target_kwargs[kword] = kwargs[kword] return target_kwargs @@ -42,14 +42,19 @@ def ignore_plotter_warnings() -> None: adding them to the warnings filter. """ warnings.filterwarnings("ignore", category=FutureWarning) - if _SEABORN_VERS <= '0.13.1': - warnings.filterwarnings("ignore", category=DeprecationWarning, module="seaborn", # but actually comes from pandas - message="is_categorical_dtype is deprecated and will be removed in a future version.") + if _SEABORN_VERS <= "0.13.1": + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + module="seaborn", # but actually comes from pandas + message="is_categorical_dtype is deprecated and will be removed in a future version.", + ) -def _add_groupby_desc_column(results_df: pandas.DataFrame, - groupby_columns: Optional[List[str]] = None, - ) -> Tuple[pandas.DataFrame, List[str], str]: +def _add_groupby_desc_column( + results_df: pandas.DataFrame, + groupby_columns: Optional[List[str]] = None, +) -> Tuple[pandas.DataFrame, List[str], str]: """ Adds a group descriptor column to the results_df. 
@@ -67,17 +72,19 @@ def _add_groupby_desc_column(results_df: pandas.DataFrame, if groupby_columns is None: groupby_columns = ["tunable_config_trial_group_id", "tunable_config_id"] groupby_column = ",".join(groupby_columns) - results_df[groupby_column] = results_df[groupby_columns].astype(str).apply( - lambda x: ",".join(x), axis=1) # pylint: disable=unnecessary-lambda + results_df[groupby_column] = ( + results_df[groupby_columns].astype(str).apply(lambda x: ",".join(x), axis=1) + ) # pylint: disable=unnecessary-lambda groupby_columns.append(groupby_column) return (results_df, groupby_columns, groupby_column) -def augment_results_df_with_config_trial_group_stats(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - requested_result_cols: Optional[Iterable[str]] = None, - ) -> pandas.DataFrame: +def augment_results_df_with_config_trial_group_stats( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + requested_result_cols: Optional[Iterable[str]] = None, +) -> pandas.DataFrame: # pylint: disable=too-complex """ Add a number of useful statistical measure columns to the results dataframe. @@ -134,30 +141,46 @@ def augment_results_df_with_config_trial_group_stats(exp_data: Optional[Experime raise ValueError(f"Not enough data: {len(results_groups)}") if requested_result_cols is None: - result_cols = set(col for col in results_df.columns if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX)) + result_cols = set( + col + for col in results_df.columns + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) + ) else: - result_cols = set(col for col in requested_result_cols - if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns) - result_cols.update(set(ExperimentData.RESULT_COLUMN_PREFIX + col for col in requested_result_cols - if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns)) + result_cols = set( + col + for col in requested_result_cols + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns + ) + result_cols.update( + set( + ExperimentData.RESULT_COLUMN_PREFIX + col + for col in requested_result_cols + if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns + ) + ) def compute_zscore_for_group_agg( - results_groups_perf: "SeriesGroupBy", - stats_df: pandas.DataFrame, - result_col: str, - agg: Union[Literal["mean"], Literal["var"], Literal["std"]] + results_groups_perf: "SeriesGroupBy", + stats_df: pandas.DataFrame, + result_col: str, + agg: Union[Literal["mean"], Literal["var"], Literal["std"]], ) -> None: - results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? + results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe. 
stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean() stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std() - stats_df[result_col + f".{agg}_zscore"] = \ - (stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"]) \ - / stats_df[result_col + f".{agg}_stddev"] - stats_df.drop(columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True) + stats_df[result_col + f".{agg}_zscore"] = ( + stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"] + ) / stats_df[result_col + f".{agg}_stddev"] + stats_df.drop( + columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True + ) augmented_results_df = results_df - augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform("count") + augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform( + "count" + ) for result_col in result_cols: if not result_col.startswith(ExperimentData.RESULT_COLUMN_PREFIX): continue @@ -176,20 +199,21 @@ def compute_zscore_for_group_agg( compute_zscore_for_group_agg(results_groups_perf, stats_df, result_col, "var") quantiles = [0.50, 0.75, 0.90, 0.95, 0.99] - for quantile in quantiles: # TODO: can we do this in one pass? + for quantile in quantiles: # TODO: can we do this in one pass? quantile_col = f"{result_col}.p{int(quantile * 100)}" stats_df[quantile_col] = results_groups_perf.transform("quantile", quantile) augmented_results_df = pandas.concat([augmented_results_df, stats_df], axis=1) return augmented_results_df -def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - top_n_configs: int = 10, - method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", - ) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: +def limit_top_n_configs( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + top_n_configs: int = 10, + method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", +) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: # pylint: disable=too-many-locals """ Utility function to process the results and determine the best performing @@ -219,7 +243,9 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, raise ValueError(f"Invalid method: {method}") # Prepare the orderby columns. - (results_df, objs_cols) = expand_results_data_args(exp_data, results_df=results_df, objectives=objectives) + (results_df, objs_cols) = expand_results_data_args( + exp_data, results_df=results_df, objectives=objectives + ) assert isinstance(results_df, pandas.DataFrame) # Augment the results dataframe with some useful stats. @@ -232,13 +258,17 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, # results_df is not None and is in fact a DataFrame, so we periodically assert # it in this func for now. 
assert results_df is not None - orderby_cols: Dict[str, bool] = {obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items()} + orderby_cols: Dict[str, bool] = { + obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items() + } config_id_col = "tunable_config_id" - group_id_col = "tunable_config_trial_group_id" # first trial_id per config group + group_id_col = "tunable_config_trial_group_id" # first trial_id per config group trial_id_col = "trial_id" - default_config_id = results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id + default_config_id = ( + results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id + ) assert default_config_id is not None, "Failed to determine default config id." # Filter out configs whose variance is too large. @@ -250,16 +280,18 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, singletons_mask = results_df["tunable_config_trial_group_size"] == 1 else: singletons_mask = results_df["tunable_config_trial_group_size"] > 1 - results_df = results_df.loc[( - (results_df[f"{obj_col}.var_zscore"].abs() < 2) - | (singletons_mask) - | (results_df[config_id_col] == default_config_id) - )] + results_df = results_df.loc[ + ( + (results_df[f"{obj_col}.var_zscore"].abs() < 2) + | (singletons_mask) + | (results_df[config_id_col] == default_config_id) + ) + ] assert results_df is not None # Also, filter results that are worse than the default. default_config_results_df = results_df.loc[results_df[config_id_col] == default_config_id] - for (orderby_col, ascending) in orderby_cols.items(): + for orderby_col, ascending in orderby_cols.items(): default_vals = default_config_results_df[orderby_col].unique() assert len(default_vals) == 1 default_val = default_vals[0] @@ -271,29 +303,38 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, # Now regroup and filter to the top-N configs by their group performance dimensions. assert results_df is not None - group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[orderby_cols.keys()] - top_n_config_ids: List[int] = group_results_df.sort_values( - by=list(orderby_cols.keys()), ascending=list(orderby_cols.values())).head(top_n_configs).index.tolist() + group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[ + orderby_cols.keys() + ] + top_n_config_ids: List[int] = ( + group_results_df.sort_values( + by=list(orderby_cols.keys()), ascending=list(orderby_cols.values()) + ) + .head(top_n_configs) + .index.tolist() + ) # Remove the default config if it's included. We'll add it back later. if default_config_id in top_n_config_ids: top_n_config_ids.remove(default_config_id) # Get just the top-n config results. # Sort by the group ids. - top_n_config_results_df = results_df.loc[( - results_df[config_id_col].isin(top_n_config_ids) - )].sort_values([group_id_col, config_id_col, trial_id_col]) + top_n_config_results_df = results_df.loc[ + (results_df[config_id_col].isin(top_n_config_ids)) + ].sort_values([group_id_col, config_id_col, trial_id_col]) # Place the default config at the top of the list. 
top_n_config_ids.insert(0, default_config_id) - top_n_config_results_df = pandas.concat([default_config_results_df, top_n_config_results_df], axis=0) + top_n_config_results_df = pandas.concat( + [default_config_results_df, top_n_config_results_df], axis=0 + ) return (top_n_config_results_df, top_n_config_ids, orderby_cols) def plot_optimizer_trends( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, ) -> None: """ Plots the optimizer trends for the Experiment. @@ -312,12 +353,16 @@ def plot_optimizer_trends( (results_df, obj_cols) = expand_results_data_args(exp_data, results_df, objectives) (results_df, groupby_columns, groupby_column) = _add_groupby_desc_column(results_df) - for (objective_column, ascending) in obj_cols.items(): + for objective_column, ascending in obj_cols.items(): incumbent_column = objective_column + ".incumbent" # Determine the mean of each config trial group to match the box plots. - group_results_df = results_df.groupby(groupby_columns)[objective_column].mean()\ - .reset_index().sort_values(groupby_columns) + group_results_df = ( + results_df.groupby(groupby_columns)[objective_column] + .mean() + .reset_index() + .sort_values(groupby_columns) + ) # # Note: technically the optimizer (usually) uses the *first* result for a # given config trial group before moving on to a new config (x-axis), so @@ -355,24 +400,29 @@ def plot_optimizer_trends( ax=axis, ) - plt.yscale('log') + plt.yscale("log") plt.ylabel(objective_column.replace(ExperimentData.RESULT_COLUMN_PREFIX, "")) plt.xlabel("Config Trial Group ID, Config ID") plt.xticks(rotation=90, fontsize=8) - plt.title("Optimizer Trends for Experiment: " + exp_data.experiment_id if exp_data is not None else "") + plt.title( + "Optimizer Trends for Experiment: " + exp_data.experiment_id + if exp_data is not None + else "" + ) plt.grid() plt.show() # type: ignore[no-untyped-call] -def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - with_scatter_plot: bool = False, - **kwargs: Any, - ) -> None: +def plot_top_n_configs( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + with_scatter_plot: bool = False, + **kwargs: Any, +) -> None: # pylint: disable=too-many-locals """ Plots the top-N configs along with the default config for the given ExperimentData. 
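For context on the mlos_viz plotting helpers whose signatures are reformatted in this hunk, a minimal usage sketch (not part of the patch; it assumes an ExperimentData instance `exp_data` loaded elsewhere, and the objective name "score" is a placeholder, not a real MLOS metric):

# Illustrative only: exercise the reformatted helpers end-to-end.
# `exp_data` is assumed to be an mlos_bench ExperimentData loaded elsewhere;
# the "score" objective is an invented column name for illustration.
from mlos_viz.base import plot_optimizer_trends, plot_top_n_configs

plot_optimizer_trends(exp_data, objectives={"score": "min"})
# Extra kwargs (e.g., top_n_configs) are forwarded on to limit_top_n_configs().
plot_top_n_configs(exp_data, objectives={"score": "min"}, top_n_configs=5)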
@@ -400,12 +450,16 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, top_n_config_args["results_df"] = results_df if "objectives" not in top_n_config_args: top_n_config_args["objectives"] = objectives - (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs(exp_data=exp_data, **top_n_config_args) + (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs( + exp_data=exp_data, **top_n_config_args + ) - (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column(top_n_config_results_df) + (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column( + top_n_config_results_df + ) top_n = len(top_n_config_results_df[groupby_column].unique()) - 1 - for (orderby_col, ascending) in orderby_cols.items(): + for orderby_col, ascending in orderby_cols.items(): opt_tgt = orderby_col.replace(ExperimentData.RESULT_COLUMN_PREFIX, "") (_fig, axis) = plt.subplots() sns.violinplot( @@ -425,12 +479,12 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, plt.grid() (xticks, xlabels) = plt.xticks() # default should be in the first position based on top_n_configs() return - xlabels[0] = "default" # type: ignore[call-overload] - plt.xticks(xticks, xlabels) # type: ignore[arg-type] + xlabels[0] = "default" # type: ignore[call-overload] + plt.xticks(xticks, xlabels) # type: ignore[arg-type] plt.xlabel("Config Trial Group, Config ID") plt.xticks(rotation=90) plt.ylabel(opt_tgt) - plt.yscale('log') + plt.yscale("log") extra_title = "(lower is better)" if ascending else "(lower is better)" plt.title(f"Top {top_n} configs {opt_tgt} {extra_title}") plt.show() # type: ignore[no-untyped-call] diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 504486a58c..beeba3248f 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -15,10 +15,12 @@ from mlos_viz.util import expand_results_data_args -def plot(exp_data: Optional[ExperimentData] = None, *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - ) -> None: +def plot( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, +) -> None: """ Plots the Experiment results data using dabl. 
@@ -44,17 +46,45 @@ def ignore_plotter_warnings() -> None: """ # pylint: disable=import-outside-toplevel warnings.filterwarnings("ignore", category=FutureWarning) - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Could not infer format") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, - message="Missing values in target_col have been removed for regression") + warnings.filterwarnings( + "ignore", module="dabl", category=UserWarning, message="Could not infer format" + ) + warnings.filterwarnings( + "ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers" + ) + warnings.filterwarnings( + "ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated" + ) + warnings.filterwarnings( + "ignore", + module="dabl", + category=UserWarning, + message="Missing values in target_col have been removed for regression", + ) from sklearn.exceptions import UndefinedMetricWarning - warnings.filterwarnings("ignore", module="sklearn", category=UndefinedMetricWarning, message="Recall is ill-defined") - warnings.filterwarnings("ignore", category=DeprecationWarning, - message="is_categorical_dtype is deprecated and will be removed in a future version.") - warnings.filterwarnings("ignore", category=DeprecationWarning, module="sklearn", - message="is_sparse is deprecated and will be removed in a future version.") + + warnings.filterwarnings( + "ignore", + module="sklearn", + category=UndefinedMetricWarning, + message="Recall is ill-defined", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="is_categorical_dtype is deprecated and will be removed in a future version.", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + module="sklearn", + message="is_sparse is deprecated and will be removed in a future version.", + ) from matplotlib._api.deprecation import MatplotlibDeprecationWarning - warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning, module="dabl", - message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed") + + warnings.filterwarnings( + "ignore", + category=MatplotlibDeprecationWarning, + module="dabl", + message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed", + ) diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py index 06ac4a7664..e5528f9875 100644 --- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py +++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py @@ -30,5 +30,5 @@ def test_plot(mock_show: Mock, mock_boxplot: Mock, exp_data: ExperimentData) -> warnings.simplefilter("error") random.seed(42) plot(exp_data, filter_warnings=True) - assert mock_show.call_count >= 2 # from the two base plots and anything dabl did - assert mock_boxplot.call_count >= 1 # from anything dabl did + assert mock_show.call_count >= 2 # from the two base plots and anything dabl did + assert mock_boxplot.call_count >= 1 # from anything dabl did diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index 744fe28648..8f426810f8 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -49,11 +49,14 @@ def expand_results_data_args( raise ValueError("Must provide either exp_data or both 
results_df and objectives.") objectives = exp_data.objectives objs_cols: Dict[str, bool] = {} - for (opt_tgt, opt_dir) in objectives.items(): + for opt_tgt, opt_dir in objectives.items(): if opt_dir not in ["min", "max"]: raise ValueError(f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}") ascending = opt_dir == "min" - if opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and opt_tgt in results_df.columns: + if ( + opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) + and opt_tgt in results_df.columns + ): objs_cols[opt_tgt] = ascending elif ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt in results_df.columns: objs_cols[ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt] = ascending diff --git a/mlos_viz/mlos_viz/version.py b/mlos_viz/mlos_viz/version.py index 607c7cc014..d418ae43c7 100644 --- a/mlos_viz/mlos_viz/version.py +++ b/mlos_viz/mlos_viz/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index 98d12598e1..638a28469a 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -21,15 +21,16 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -47,22 +48,22 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } @@ -70,23 +71,23 @@ def _get_long_desc_from_readme(base_url: str) -> dict: # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. 
-extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", ] setup( version=VERSION, install_requires=[ - 'mlos-bench==' + VERSION, - 'dabl>=0.2.6', - 'matplotlib<3.9', # FIXME: https://github.com/dabl/dabl/pull/341 + "mlos-bench==" + VERSION, + "dabl>=0.2.6", + "matplotlib<3.9", # FIXME: https://github.com/dabl/dabl/pull/341 ], extras_require=extra_requires, - **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_viz'), + **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_viz"), ) From af71e046e805d86c7eb52eff75ca2a45f31a6a2e Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 16:50:08 +0000 Subject: [PATCH 14/54] tweaks for comments for new line length --- doc/source/conf.py | 3 +- .../config/schemas/config_schemas.py | 13 +++- .../mlos_bench/environments/composite_env.py | 3 +- mlos_bench/mlos_bench/launcher.py | 12 ++-- .../mlos_bench/services/base_service.py | 3 +- .../remote/azure/azure_deployment_services.py | 3 +- .../remote/azure/azure_network_services.py | 3 +- .../services/remote/ssh/ssh_fileshare.py | 3 +- .../services/remote/ssh/ssh_host_service.py | 6 +- .../services/remote/ssh/ssh_service.py | 3 +- .../services/types/host_provisioner_type.py | 3 +- .../types/network_provisioner_type.py | 3 +- .../storage/base_experiment_data.py | 3 +- .../mlos_bench/storage/sql/experiment.py | 3 +- .../mlos_bench/storage/sql/experiment_data.py | 3 +- .../test_load_environment_config_examples.py | 6 +- .../config/schemas/cli/test_cli_schemas.py | 10 +-- .../environments/test_environment_schemas.py | 7 +- .../schemas/globals/test_globals_schemas.py | 5 +- .../optimizers/test_optimizer_schemas.py | 3 +- .../schemas/services/test_services_schemas.py | 3 +- .../test_tunable_values_schemas.py | 5 +- .../environments/local/local_env_test.py | 3 +- .../optimizers/grid_search_optimizer_test.py | 3 +- .../tests/services/remote/ssh/__init__.py | 3 +- .../bayesian_optimizers/bayesian_optimizer.py | 14 ++-- .../bayesian_optimizers/smac_optimizer.py | 33 ++++++--- .../mlos_core/optimizers/flaml_optimizer.py | 40 +++++++---- mlos_core/mlos_core/optimizers/optimizer.py | 19 +++-- .../mlos_core/optimizers/random_optimizer.py | 3 +- .../mlos_core/spaces/adapters/adapter.py | 20 ++++-- .../mlos_core/spaces/adapters/llamatune.py | 69 +++++++++++++------ .../tests/optimizers/optimizer_test.py | 15 ++-- .../tests/spaces/adapters/llamatune_test.py | 3 +- mlos_core/setup.py | 4 +- mlos_viz/mlos_viz/base.py | 9 ++- mlos_viz/mlos_viz/util.py | 3 +- 37 files changed, 235 insertions(+), 112 deletions(-) diff --git a/doc/source/conf.py b/doc/source/conf.py index 3e25d9b082..4567d15a5d 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -89,7 +89,8 @@ autodoc_default_options = { 'members': True, 'undoc-members': True, - # Don't generate documentation for some (non-private) functions that are more for internal implementation use. + # Don't generate documentation for some (non-private) functions that are + # more for internal implementation use. 
'exclude-members': 'mlos_bench.util.check_required_params' } diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index 181f96e5d6..56ea8b7879 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -3,7 +3,8 @@ # Licensed under the MIT License. # """ -A simple class for describing where to find different config schemas and validating configs against them. +A simple class for describing where to find different config schemas and +validating configs against them. """ import json # schema files are pure json - no comments @@ -62,7 +63,10 @@ def __getitem__(self, key: str) -> dict: @classmethod def _load_schemas(cls) -> None: - """Loads all schemas and subschemas into the schema store for the validator to reference.""" + """ + Loads all schemas and subschemas into the schema store for the + validator to reference. + """ if cls._SCHEMA_STORE: return for root, _, files in walk(CONFIG_SCHEMA_DIR): @@ -82,7 +86,10 @@ def _load_schemas(cls) -> None: @classmethod def _load_registry(cls) -> None: - """Also store them in a Registry object for referencing by recent versions of jsonschema.""" + """ + Also store them in a Registry object for referencing by recent versions + of jsonschema. + """ if not cls._SCHEMA_STORE: cls._load_schemas() cls._REGISTRY = Registry().with_resources( diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index 36ab99a223..72bb799a0e 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -254,5 +254,6 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: final_status = final_status or status _LOG.info("Final status: %s :: %s", self, final_status) - # Return the status and the timestamp of the last child environment or the first failed child environment. + # Return the status and the timestamp of the last child environment or + # the first failed child environment. return (final_status, timestamp, joint_telemetry) diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index d988e370b3..298cdf65c9 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -48,10 +48,12 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # pylint: disable=too-many-statements _LOG.info("Launch: %s", description) epilog = """ - Additional --key=value pairs can be specified to augment or override values listed in --globals. + Additional --key=value pairs can be specified to augment or + override values listed in --globals. Other required_args values can also be pulled from shell environment variables. - For additional details, please see the website or the README.md files in the source tree: + For additional details, please see the website or the README.md + files in the source tree: """ parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog) @@ -92,11 +94,13 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st args_rest, {key: val for (key, val) in config.items() if key not in vars(args)}, ) - # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI. + # experiment_id is generally taken from --globals files, but we also + # allow overriding it on the CLI. 
# It's useful to keep it there explicitly mostly for the --help output. if args.experiment_id: self.global_config["experiment_id"] = args.experiment_id - # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line + # trial_config_repeat_count is a scheduler property but it's convenient + # to set it via command line if args.trial_config_repeat_count: self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count # Ensure that the trial_id is present since it gets used by some other diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index 65725b6288..316aef2feb 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -210,7 +210,8 @@ def _validate_json_config(self, config: dict) -> None: file loading mechanism. """ if self.__class__ == Service: - # Skip over the case where instantiate a bare base Service class in order to build up a mix-in. + # Skip over the case where instantiate a bare base Service class in + # order to build up a mix-in. assert config == {} return json_config: dict = { diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 3673baca76..7f779ff830 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -277,7 +277,8 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for resource being deployed; otherwise, wait for successful deprovisioning. + If True, wait for resource being deployed; otherwise, wait for + successful deprovisioning. Returns ------- diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index 4ba8bd3903..fb630eb1de 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -30,7 +30,8 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # Azure Compute REST API calls as described in # https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 - # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 + # From: + # https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 _URL_DEPROVISION = ( "https://management.azure.com" + "/subscriptions/{subscription}" diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index f136747f7f..99899b6917 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -46,7 +46,8 @@ async def _start_file_copy( Parameters ---------- params : dict - Flat dictionary of (key, value) pairs of parameters (used for establishing the connection). + Flat dictionary of (key, value) pairs of parameters (used for + establishing the connection). mode : CopyMode Whether to download or upload the file. 
local_path : str diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index f04544eb05..0bb5cf16dd 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -79,7 +79,8 @@ async def _run_cmd( Parameters ---------- params : dict - Flat dictionary of (key, value) pairs of parameters (used for establishing the connection). + Flat dictionary of (key, value) pairs of parameters (used for + establishing the connection). cmd : str Command(s) to run via shell. @@ -92,7 +93,8 @@ async def _run_cmd( # Script should be an iterable of lines, not an iterable string. script = [script] connection, _ = await self._get_client_connection(params) - # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values. + # Note: passing environment variables to SSH servers is typically + # restricted to just some LC_* values. # Handle transferring environment variables by making a script to set them. env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()] script_lines = env_script_lines + [ diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 64bb7d9788..b960a84deb 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -274,7 +274,8 @@ def __init__( # available can confuse some commands, though we may need to make # this configurable in the future. "request_pty": False, - # By default disable known_hosts checking (since most VMs expected to be dynamically created). + # By default disable known_hosts checking (since most VMs expected + # to be dynamically created). "known_hosts": None, } diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index 77b481e48e..3d80055197 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -46,7 +46,8 @@ def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Status params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for Host/VM being deployed; otherwise, wait for successful deprovisioning. + If True, wait for Host/VM being deployed; otherwise, wait for + successful deprovisioning. Returns ------- diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 50b24cc4b8..27232b54cd 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -46,7 +46,8 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for Network being deployed; otherwise, wait for successful deprovisioning. + If True, wait for Network being deployed; otherwise, wait for + successful deprovisioning. 
Returns ------- diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index 47581f0725..eadbb91fd9 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -162,7 +162,8 @@ def results_df(self) -> pandas.DataFrame: ------- results : pandas.DataFrame A DataFrame with configurations and results from all trials of the experiment. - Has columns [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status] + Has columns + [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status] followed by tunable config parameters (prefixed with "config.") and trial results (prefixed with "result."). The latter can be NULLs if the trial was not successful. diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index e6322c7ade..3024846c72 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -218,7 +218,8 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D .select_from(table) .where(*[column(key) == val for (key, val) in kwargs.items()]) ) - # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts. + # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to + # avoid naming conflicts. return dict( row._tuple() for row in cur_result.fetchall() ) # pylint: disable=protected-access diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index f299bcff68..48f9303c59 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -74,7 +74,8 @@ def objectives(self) -> Dict[str, Literal["min", "max"]]: for objective in objectives_db_data.fetchall() } - # TODO: provide a way to get individual data to avoid repeated bulk fetches where only small amounts of data is accessed. + # TODO: provide a way to get individual data to avoid repeated bulk fetches + # where only small amounts of data is accessed. # Or else make the TrialData object lazily populate. @property diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 2369b0c27a..71a6741106 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -125,8 +125,10 @@ def test_load_composite_env_config_examples( assert child_group is composite_group checked_child_env_groups.add(child_group.name) - # Check that when we change a child env, it's value is reflected in the composite env as well. - # That is to say, they refer to the same objects, despite having potentially been loaded from separate configs. + # Check that when we change a child env, it's value is reflected in + # the composite env as well. + # That is to say, they refer to the same objects, despite having + # potentially been loaded from separate configs. 
if child_tunable.is_categorical: old_cat_value = child_tunable.category assert child_tunable.value == old_cat_value diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 32ea0b9713..5d97ca01c5 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -35,8 +35,9 @@ def test_cli_configs_against_schema(test_case_name: str) -> None: check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.CLI) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat + # dicts with minor constraints on them, so adding/removing params + # doesn't invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @@ -50,6 +51,7 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat + # dicts with minor constraints on them, so adding/removing params + # doesn't invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index 1528d8d164..5ce1c0e727 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -33,7 +33,9 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_ENV_CLASSES = { - ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. + # ScriptEnv is ABCMeta abstract, but there's no good way to test that + # dynamically in Python. + ScriptEnv } expected_environment_class_names = [ subclass.__module__ + "." + subclass.__name__ @@ -80,7 +82,8 @@ def test_environment_configs_against_schema(test_case_name: str) -> None: @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_environment_configs_with_extra_param(test_case_name: str) -> None: """ - Checks that the environment config fails to validate if extra params are present in certain places. + Checks that the environment config fails to validate if extra params are + present in certain places. 
""" check_test_case_config_with_extra_param( TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 508787a84b..2c485b7e30 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -34,6 +34,7 @@ def test_globals_configs_against_schema(test_case_name: str) -> None: check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat + # dicts with minor constraints on them, so adding/removing params + # doesn't invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index ef5c0edfa3..9e5d9d72d1 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -41,7 +41,8 @@ ] assert expected_mlos_bench_optimizer_class_names -# Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None). +# Also make sure that we check for configs where the optimizer_type or +# space_adapter_type are left unspecified (None). expected_mlos_core_optimizer_types = list(OptimizerType) + [None] assert expected_mlos_core_optimizer_types diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 032b4c0aad..4d41600d34 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -39,7 +39,8 @@ NON_CONFIG_SERVICE_CLASSES = { ConfigPersistenceService, # configured thru the launcher cli args - TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. + # ABCMeta abstract class, but no good way to test that dynamically in Python. 
+ TempDirContextService, AzureDeploymentService, # ABCMeta abstract base class SshService, # ABCMeta abstract base class } diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index 33124134e9..77f1d776ea 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -34,6 +34,7 @@ def test_tunable_values_configs_against_schema(test_case_name: str) -> None: check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat + # dicts with minor constraints on them, so adding/removing params + # doesn't invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 2b51ae1f0e..d72036fbf5 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -43,7 +43,8 @@ def test_local_env(tunable_groups: TunableGroups) -> None: def test_local_env_service_context(tunable_groups: TunableGroups) -> None: """ - Basic check that context support for Service mixins are handled when environment contexts are entered. + Basic check that context support for Service mixins are handled when + environment contexts are entered. """ local_env = create_local_env(tunable_groups, {"run": ["echo NA"]}) # pylint: disable=protected-access diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index add2945d74..80b6cd148b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -192,7 +192,8 @@ def test_grid_search( assert not list(grid_search_opt.suggested_configs) assert not grid_search_opt.not_converged() - # But if we still have iterations left, we should be able to suggest again by refilling the grid. + # But if we still have iterations left, we should be able to suggest again + # by refilling the grid. assert grid_search_opt.current_iteration < grid_search_opt.max_iterations assert grid_search_opt.suggest() assert list(grid_search_opt.pending_configs) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index 16c88dc791..b25b7c0534 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -39,7 +39,8 @@ def get_port(self, uncached: bool = False) -> int: """ Gets the port that the SSH test server is listening on. - Note: this value can change when the service restarts so we can't rely on the DockerServices. 
+ Note: this value can change when the service restarts so we can't rely + on the DockerServices. """ if self._port is None or uncached: port_cmd = run( diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 9d3bcabcb2..8f81282553 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -22,12 +22,15 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): def surrogate_predict( self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None ) -> npt.NDArray: - """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). + """ + Obtain a prediction from this Bayesian optimizer's surrogate model for + the given configuration(s). Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. context : pd.DataFrame Not Yet Implemented. @@ -38,12 +41,15 @@ def surrogate_predict( def acquisition_function( self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None ) -> npt.NDArray: - """Invokes the acquisition function from this Bayesian optimizer for the given configuration. + """ + Invokes the acquisition function from this Bayesian optimizer for the + given configuration. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. context : pd.DataFrame Not Yet Implemented. diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 5784a42f12..362b9b6ce6 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -64,18 +64,21 @@ def __init__( seed : Optional[int] By default SMAC uses a known seed (0) to keep results reproducible. - However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC. + However, if a `None` seed is explicitly provided, we let a random + seed be produced by SMAC. run_name : Optional[str] Name of this run. This is used to easily distinguish across different runs. If set to `None` (default), SMAC will generate a hash from metadata. output_directory : Optional[str] - The directory where SMAC output will saved. If set to `None` (default), a temporary dir will be used. + The directory where SMAC output will saved. If set to `None` + (default), a temporary dir will be used. max_trials : int Maximum number of trials (i.e., function evaluations) to be run. Defaults to 100. - Note that modifying this value directly affects the value of `n_random_init`, if latter is set to `None`. + Note that modifying this value directly affects the value of + `n_random_init`, if latter is set to `None`. n_random_init : Optional[int] Number of points evaluated at start to bootstrap the optimizer. @@ -119,7 +122,8 @@ def __init__( self.trial_info_map: Dict[ConfigSpace.Configuration, TrialInfo] = {} # The default when not specified is to use a known seed (0) to keep results reproducible. 
- # However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC. + # However, if a `None` seed is explicitly provided, we let a random + # seed be produced by SMAC. # https://automl.github.io/SMAC3/main/api/smac.scenario.html#smac.scenario.Scenario seed = -1 if seed is None else seed @@ -224,9 +228,11 @@ def __del__(self) -> None: @property def n_random_init(self) -> int: """ - Gets the number of random samples to use to initialize the optimizer's search space sampling. + Gets the number of random samples to use to initialize the optimizer's + search space sampling. - Note: This may not be equal to the value passed to the initializer, due to logic present in the SMAC. + Note: This may not be equal to the value passed to the initializer, due + to logic present in the SMAC. See Also: max_ratio Returns @@ -251,8 +257,10 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None seed : int Random seed to use for the target function. Not actually used. """ - # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API - # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946 + # NOTE: Providing a target function when using the ask-and-tell + # interface is an imperfection of the API -- this is planned to be + # fixed in some future release: + # https://github.com/automl/SMAC3/issues/946 raise RuntimeError("This function should never be called.") def _register( @@ -268,7 +276,8 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -292,7 +301,8 @@ def _register( for config, (_i, score) in zip( self._to_configspace_configs(configs=configs), scores.iterrows() ): - # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance + # Retrieve previously generated TrialInfo (returned by .ask()) or + # create new TrialInfo instance info: TrialInfo = self.trial_info_map.get( config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) ) @@ -404,7 +414,8 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. Returns ------- diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 2df19b8eb2..164ce40bb3 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -30,7 +30,8 @@ class FlamlOptimizer(BaseOptimizer): Wrapper class for FLAML Optimizer: A fast library for AutoML and tuning. """ - # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. + # The name of an internal objective attribute that is calculated as a + # weighted average of the user provided objective metrics. 
_METRIC_NAME = "FLAML_score" def __init__( @@ -62,10 +63,12 @@ def __init__( low_cost_partial_config : dict A dictionary from a subset of controlled dimensions to the initial low-cost values. - More info: https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune + More info: + https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune seed : Optional[int] - If provided, calls np.random.seed() with the provided value to set the seed globally at init. + If provided, calls np.random.seed() with the provided value to set + the seed globally at init. """ super().__init__( parameter_space=parameter_space, @@ -106,7 +109,8 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -170,19 +174,23 @@ def register_pending( def _target_function(self, config: dict) -> Union[dict, None]: """Configuration evaluation function called by FLAML optimizer. - FLAML may suggest the same configuration multiple times (due to its warm-start mechanism). - Once FLAML suggests an unseen configuration, we store it, and stop the optimization process. + FLAML may suggest the same configuration multiple times (due to its + warm-start mechanism). + Once FLAML suggests an unseen configuration, we store it, and stop the + optimization process. Parameters ---------- config: dict Next configuration to be evaluated, as suggested by FLAML. - This config is stored internally and is returned to user, via `.suggest()` method. + This config is stored internally and is returned to user, via + `.suggest()` method. Returns ------- result: Union[dict, None] - Dictionary with a single key, `FLAML_score`, if config already evaluated; `None` otherwise. + Dictionary with a single key, `FLAML_score`, if config already + evaluated; `None` otherwise. """ cs_config = normalize_config(self.optimizer_parameter_space, config) if cs_config in self.evaluated_samples: @@ -192,12 +200,16 @@ def _target_function(self, config: dict) -> Union[dict, None]: return None # Returning None stops the process def _get_next_config(self) -> dict: - """Warm-starts a new instance of FLAML, and returns a recommended, unseen new configuration. - - Since FLAML does not provide an ask-and-tell interface, we need to create a new instance of FLAML - each time we get asked for a new suggestion. This is suboptimal performance-wise, but works. - To do so, we use any previously evaluated configs to bootstrap FLAML (i.e., warm-start). - For more info: https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start + """ + Warm-starts a new instance of FLAML, and returns a recommended, unseen + new configuration. + + Since FLAML does not provide an ask-and-tell interface, we need to + create a new instance of FLAML each time we get asked for a new + suggestion. This is suboptimal performance-wise, but works. To do so, + we use any previously evaluated configs to bootstrap FLAML (i.e., + warm-start). 
For more info: + https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start Returns ------- diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index f96bce7075..0f600c76bd 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -80,12 +80,15 @@ def register( context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None, ) -> None: - """Wrapper method, which employs the space adapter (if any), before registering the configs and scores. + """ + Wrapper method, which employs the space adapter (if any), before + registering the configs and scores. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -133,7 +136,8 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -146,7 +150,8 @@ def suggest( self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ - Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. + Wrapper method, which employs the space adapter (if any), after + suggesting a new configuration. Parameters ---------- @@ -209,13 +214,15 @@ def register_pending( metadata: Optional[pd.DataFrame] = None, ) -> None: """Registers the given configs as "pending". - That is it say, it has been suggested by the optimizer, and an experiment trial has been started. + That is it say, it has been suggested by the optimizer, and an + experiment trial has been started. This can be useful for executing multiple trials in parallel, retry logic, etc. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. context : pd.DataFrame Not Yet Implemented. metadata : Optional[pd.DataFrame] diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index bf6f85ff88..b5e1de93fc 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -39,7 +39,8 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names + and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. 
diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index 58d07763f6..5f45414e1a 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -50,13 +50,16 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: - """Translates a configuration, which belongs to the target parameter space, to the original parameter space. + """ + Translates a configuration, which belongs to the target parameter + space, to the original parameter space. This method is called by the `suggest` method of the `BaseOptimizer` class. Parameters ---------- configuration : pd.DataFrame - Pandas dataframe with a single row. Column names are the parameter names of the target parameter space. + Pandas dataframe with a single row. Column names are the parameter + names of the target parameter space. Returns ------- @@ -68,20 +71,25 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: - """Translates a configuration, which belongs to the original parameter space, to the target parameter space. - This method is called by the `register` method of the `BaseOptimizer` class, and performs the inverse operation + """ + Translates a configuration, which belongs to the original parameter + space, to the target parameter space. + This method is called by the `register` method of the `BaseOptimizer` + class, and performs the inverse operation of `BaseSpaceAdapter.transform` method. Parameters ---------- configurations : pd.DataFrame Dataframe of configurations / parameters, which belong to the original parameter space. - The columns are the parameter names the original parameter space and the rows are the configurations. + The columns are the parameter names the original parameter space + and the rows are the configurations. Returns ------- configurations : pd.DataFrame Dataframe of the translated configurations / parameters. - The columns are the parameter names of the target parameter space and the rows are the configurations. + The columns are the parameter names of the target parameter space + and the rows are the configurations. 
""" pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index b8abdedfeb..4fa459fd5d 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -26,13 +26,21 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-a """ DEFAULT_NUM_LOW_DIMS = 16 - """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" + """ + Default number of dimensions in the low-dimensional search space, generated + by HeSBO projection + """ DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2 - """Default percentage of bias for each special parameter value""" + """ + Default percentage of bias for each special parameter value + """ DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 - """Default number of (max) unique values of each parameter, when space discretization is used""" + """ + Default number of (max) unique values of each parameter, when space + discretization is used + """ def __init__( self, @@ -101,11 +109,15 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: ) target_config = self._suggested_configs.get(configuration, None) - # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection - # To (partly) support this operation, we keep track of the suggested low-dim point(s) along with the - # respective high-dim point; this way we can retrieve the low-dim point, from its high-dim counterpart. + # NOTE: HeSBO is a non-linear projection method, and does not + # inherently support inverse projection. + # To (partly) support this operation, we keep track of the + # suggested low-dim point(s) along with the respective high-dim + # point; this way we can retrieve the low-dim point, from its + # high-dim counterpart. if target_config is None: - # Inherently it is not supported to register points, which were not suggested by the optimizer. + # Inherently it is not supported to register points, which were + # not suggested by the optimizer. if configuration == self.orig_parameter_space.get_default_configuration(): # Default configuration should always be registerable. pass @@ -117,7 +129,8 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: "previously by the optimizer can be registered." ) - # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix. + # ...yet, we try to support that by implementing an approximate + # reverse mapping using pseudo-inverse matrix. if getattr(self, "_pinv_matrix", None) is None: self._try_generate_approx_inverse_mapping() @@ -181,9 +194,12 @@ def _construct_low_dim_space( for idx in range(num_low_dims) ] else: - # Currently supported optimizers do not support defining a discretized space (like ConfigSpace does using `q` kwarg). - # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters. - # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection. + # Currently supported optimizers do not support defining a + # discretized space (like ConfigSpace does using `q` kwarg). + # Thus, to support space discretization, we define the + # low-dimensional space using integer hyperparameters. + # We also employ a scaler, which scales suggested values to [-1, 1] + # range, used by HeSBO projection. 
hyperparameters = [ ConfigSpace.UniformIntegerHyperparameter( name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param @@ -191,7 +207,8 @@ def _construct_low_dim_space( for idx in range(num_low_dims) ] - # Initialize quantized values scaler: from [0, max_unique_values_per_param] to (-1, 1) range + # Initialize quantized values scaler: from [0, + # max_unique_values_per_param] to (-1, 1) range q_scaler = MinMaxScaler(feature_range=(-1, 1)) ones_vector = np.ones(num_low_dims) max_value_vector = ones_vector * max_unique_values_per_param @@ -208,8 +225,10 @@ def _construct_low_dim_space( self._target_config_space = config_space def _transform(self, configuration: dict) -> dict: - """Projects a low-dimensional point (configuration) to the high-dimensional original parameter space, - and then biases the resulting parameter values towards their special value(s) (if any). + """ + Projects a low-dimensional point (configuration) to the + high-dimensional original parameter space, and then biases the + resulting parameter values towards their special value(s) (if any). Parameters ---------- @@ -266,7 +285,9 @@ def _transform(self, configuration: dict) -> dict: def _special_param_value_scaler( self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float ) -> float: - """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. + """ + Biases the special value(s) of this parameter, by shifting the + normalized `input_value` towards those. Parameters ---------- @@ -344,7 +365,8 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non elif all( isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value ): - # User specifies list of tuples; each tuple defines the special value and the biasing percentage + # User specifies list of tuples; each tuple defines the + # special value and the biasing percentage tuple_list = value else: raise ValueError( @@ -395,11 +417,16 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non self._special_param_values_dict = sanitized_dict def _try_generate_approx_inverse_mapping(self) -> None: - """Tries to generate an approximate reverse mapping: i.e., from high-dimensional space to the low-dimensional one. - Reverse mapping is generated using the pseudo-inverse matrix, of original HeSBO projection matrix. - This mapping can be potentially used to register configurations that were *not* previously suggested by the optimizer. - - NOTE: This method is experimental, and there is currently no guarantee that it works as expected. + """ + Tries to generate an approximate reverse mapping: i.e., from + high-dimensional space to the low-dimensional one. + Reverse mapping is generated using the pseudo-inverse matrix, of + original HeSBO projection matrix. + This mapping can be potentially used to register configurations that + were *not* previously suggested by the optimizer. + + NOTE: This method is experimental, and there is currently no guarantee + that it works as expected. 
Raises ------ diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index d5d00d0692..0cda96e72b 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -85,8 +85,10 @@ def test_basic_interface_toy_problem( if kwargs is None: kwargs = {} if optimizer_class == OptimizerType.SMAC.value: - # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100. - # To avoid having to train more than 25 model iterations, we set a lower number of max iterations. + # SMAC sets the initial random samples as a percentage of the max + # iterations, which defaults to 100. + # To avoid having to train more than 25 model iterations, we set a + # lower number of max iterations. kwargs["max_trials"] = max_iterations * 2 def objective(x: pd.Series) -> pd.DataFrame: @@ -136,7 +138,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert all_configs.shape == (20, 3) assert all_scores.shape == (20, 1) - # It would be better to put this into bayesian_optimizer_test but then we'd have to refit the model + # It would be better to put this into bayesian_optimizer_test but then we'd + # have to refit the model if isinstance(optimizer, BaseBayesianOptimizer): pred_best = optimizer.surrogate_predict(configs=best_config) assert pred_best.shape == (1,) @@ -322,7 +325,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: (best_config, best_score, _context) = best_observation (llamatune_best_config, llamatune_best_score, _context) = llamatune_best_observation - # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that + # LlamaTune's optimizer score should better (i.e., lower) than plain + # optimizer's one, or close to that assert ( best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] @@ -378,7 +382,8 @@ def test_mixed_numerics_type_input_space_types( optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] ) -> None: """ - Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. + Toy problem to test the optimizers with mixed numeric types to ensure that + original dtypes are retained. """ max_iterations = 10 if kwargs is None: diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index cd1b250ab7..44e23a02a6 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -562,7 +562,8 @@ def test_deterministic_behavior_for_same_seed( num_target_space_dims: int, param_space_kwargs: dict ) -> None: """ - Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. + Tests LlamaTune's space adapter deterministic behavior when given same seed + in the input parameter space. 
""" def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: diff --git a/mlos_core/setup.py b/mlos_core/setup.py index 4d895db315..e33559032f 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -90,7 +90,9 @@ def _get_long_desc_from_readme(base_url: str) -> dict: version=VERSION, install_requires=[ "scikit-learn>=1.2", - "joblib>=1.1.1", # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released + # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, + # which isn't currently released + "joblib>=1.1.1", "scipy>=1.3.2", "numpy>=1.24", "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710 diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index d2fc4edad7..357e73f1f3 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -167,7 +167,8 @@ def compute_zscore_for_group_agg( agg: Union[Literal["mean"], Literal["var"], Literal["std"]], ) -> None: results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? - # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe. + # Compute the zscore of the chosen aggregate performance of each group + # into each row in the dataframe. stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean() stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std() stats_df[result_col + f".{agg}_zscore"] = ( @@ -226,12 +227,14 @@ def limit_top_n_configs( results_df : Optional[pandas.DataFrame] The results dataframe to augment, by default None to use the results_df property. objectives : Iterable[str], optional - Which result column(s) to use for sorting the configs, and in which direction ("min" or "max"). + Which result column(s) to use for sorting the configs, and in which + direction ("min" or "max"). By default None to automatically select the experiment objectives. top_n_configs : int, optional How many configs to return, including the default, by default 20. method: Literal["mean", "median", "p50", "p75", "p90", "p95", "p99"] = "mean", - Which statistical method to use when sorting the config groups before determining the cutoff, by default "mean". + Which statistical method to use when sorting the config groups before + determining the cutoff, by default "mean". Returns ------- diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index 8f426810f8..b4fb789c00 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -36,7 +36,8 @@ def expand_results_data_args( Returns ------- Tuple[pandas.DataFrame, Dict[str, bool]] - The results dataframe and the objectives columns in the dataframe, plus whether or not they are in ascending order. + The results dataframe and the objectives columns in the dataframe, plus + whether or not they are in ascending order. """ # Prepare the orderby columns. if results_df is None: From d68712be765cc556e1627ce43302b543fda22c08 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 16:50:31 +0000 Subject: [PATCH 15/54] Revert "tweaks for comments for new line length" This reverts commit af71e046e805d86c7eb52eff75ca2a45f31a6a2e. 
---
 doc/source/conf.py | 3 +-
 .../config/schemas/config_schemas.py | 13 +---
 .../mlos_bench/environments/composite_env.py | 3 +-
 mlos_bench/mlos_bench/launcher.py | 12 ++--
 .../mlos_bench/services/base_service.py | 3 +-
 .../remote/azure/azure_deployment_services.py | 3 +-
 .../remote/azure/azure_network_services.py | 3 +-
 .../services/remote/ssh/ssh_fileshare.py | 3 +-
 .../services/remote/ssh/ssh_host_service.py | 6 +-
 .../services/remote/ssh/ssh_service.py | 3 +-
 .../services/types/host_provisioner_type.py | 3 +-
 .../types/network_provisioner_type.py | 3 +-
 .../storage/base_experiment_data.py | 3 +-
 .../mlos_bench/storage/sql/experiment.py | 3 +-
 .../mlos_bench/storage/sql/experiment_data.py | 3 +-
 .../test_load_environment_config_examples.py | 6 +-
 .../config/schemas/cli/test_cli_schemas.py | 10 ++-
 .../environments/test_environment_schemas.py | 7 +-
 .../schemas/globals/test_globals_schemas.py | 5 +-
 .../optimizers/test_optimizer_schemas.py | 3 +-
 .../schemas/services/test_services_schemas.py | 3 +-
 .../test_tunable_values_schemas.py | 5 +-
 .../environments/local/local_env_test.py | 3 +-
 .../optimizers/grid_search_optimizer_test.py | 3 +-
 .../tests/services/remote/ssh/__init__.py | 3 +-
 .../bayesian_optimizers/bayesian_optimizer.py | 14 ++--
 .../bayesian_optimizers/smac_optimizer.py | 33 +++------
 .../mlos_core/optimizers/flaml_optimizer.py | 40 ++++-------
 mlos_core/mlos_core/optimizers/optimizer.py | 19 ++---
 .../mlos_core/optimizers/random_optimizer.py | 3 +-
 .../mlos_core/spaces/adapters/adapter.py | 20 ++----
 .../mlos_core/spaces/adapters/llamatune.py | 69 ++++++-------
 .../tests/optimizers/optimizer_test.py | 15 ++--
 .../tests/spaces/adapters/llamatune_test.py | 3 +-
 mlos_core/setup.py | 4 +-
 mlos_viz/mlos_viz/base.py | 9 +--
 mlos_viz/mlos_viz/util.py | 3 +-
 37 files changed, 112 insertions(+), 235 deletions(-)

diff --git a/doc/source/conf.py b/doc/source/conf.py
index 4567d15a5d..3e25d9b082 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -89,8 +89,7 @@
 autodoc_default_options = {
     'members': True,
     'undoc-members': True,
-    # Don't generate documentation for some (non-private) functions that are
-    # more for internal implementation use.
+    # Don't generate documentation for some (non-private) functions that are more for internal implementation use.
     'exclude-members': 'mlos_bench.util.check_required_params'
 }
 
diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py
index 56ea8b7879..181f96e5d6 100644
--- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py
+++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py
@@ -3,8 +3,7 @@
 # Licensed under the MIT License.
 #
 """
-A simple class for describing where to find different config schemas and
-validating configs against them.
+A simple class for describing where to find different config schemas and validating configs against them.
 """
 
 import json  # schema files are pure json - no comments
@@ -63,10 +62,7 @@ def __getitem__(self, key: str) -> dict:
 
     @classmethod
     def _load_schemas(cls) -> None:
-        """
-        Loads all schemas and subschemas into the schema store for the
-        validator to reference.
- """ + """Loads all schemas and subschemas into the schema store for the validator to reference.""" if cls._SCHEMA_STORE: return for root, _, files in walk(CONFIG_SCHEMA_DIR): @@ -86,10 +82,7 @@ def _load_schemas(cls) -> None: @classmethod def _load_registry(cls) -> None: - """ - Also store them in a Registry object for referencing by recent versions - of jsonschema. - """ + """Also store them in a Registry object for referencing by recent versions of jsonschema.""" if not cls._SCHEMA_STORE: cls._load_schemas() cls._REGISTRY = Registry().with_resources( diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index 72bb799a0e..36ab99a223 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -254,6 +254,5 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: final_status = final_status or status _LOG.info("Final status: %s :: %s", self, final_status) - # Return the status and the timestamp of the last child environment or - # the first failed child environment. + # Return the status and the timestamp of the last child environment or the first failed child environment. return (final_status, timestamp, joint_telemetry) diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index 298cdf65c9..d988e370b3 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -48,12 +48,10 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # pylint: disable=too-many-statements _LOG.info("Launch: %s", description) epilog = """ - Additional --key=value pairs can be specified to augment or - override values listed in --globals. + Additional --key=value pairs can be specified to augment or override values listed in --globals. Other required_args values can also be pulled from shell environment variables. - For additional details, please see the website or the README.md - files in the source tree: + For additional details, please see the website or the README.md files in the source tree: """ parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog) @@ -94,13 +92,11 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st args_rest, {key: val for (key, val) in config.items() if key not in vars(args)}, ) - # experiment_id is generally taken from --globals files, but we also - # allow overriding it on the CLI. + # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI. # It's useful to keep it there explicitly mostly for the --help output. if args.experiment_id: self.global_config["experiment_id"] = args.experiment_id - # trial_config_repeat_count is a scheduler property but it's convenient - # to set it via command line + # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line if args.trial_config_repeat_count: self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count # Ensure that the trial_id is present since it gets used by some other diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index 316aef2feb..65725b6288 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -210,8 +210,7 @@ def _validate_json_config(self, config: dict) -> None: file loading mechanism. 
""" if self.__class__ == Service: - # Skip over the case where instantiate a bare base Service class in - # order to build up a mix-in. + # Skip over the case where instantiate a bare base Service class in order to build up a mix-in. assert config == {} return json_config: dict = { diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 7f779ff830..3673baca76 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -277,8 +277,7 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for resource being deployed; otherwise, wait for - successful deprovisioning. + If True, wait for resource being deployed; otherwise, wait for successful deprovisioning. Returns ------- diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index fb630eb1de..4ba8bd3903 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -30,8 +30,7 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # Azure Compute REST API calls as described in # https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 - # From: - # https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 + # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 _URL_DEPROVISION = ( "https://management.azure.com" + "/subscriptions/{subscription}" diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index 99899b6917..f136747f7f 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -46,8 +46,7 @@ async def _start_file_copy( Parameters ---------- params : dict - Flat dictionary of (key, value) pairs of parameters (used for - establishing the connection). + Flat dictionary of (key, value) pairs of parameters (used for establishing the connection). mode : CopyMode Whether to download or upload the file. local_path : str diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index 0bb5cf16dd..f04544eb05 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -79,8 +79,7 @@ async def _run_cmd( Parameters ---------- params : dict - Flat dictionary of (key, value) pairs of parameters (used for - establishing the connection). + Flat dictionary of (key, value) pairs of parameters (used for establishing the connection). cmd : str Command(s) to run via shell. @@ -93,8 +92,7 @@ async def _run_cmd( # Script should be an iterable of lines, not an iterable string. script = [script] connection, _ = await self._get_client_connection(params) - # Note: passing environment variables to SSH servers is typically - # restricted to just some LC_* values. 
+ # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values. # Handle transferring environment variables by making a script to set them. env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()] script_lines = env_script_lines + [ diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index b960a84deb..64bb7d9788 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -274,8 +274,7 @@ def __init__( # available can confuse some commands, though we may need to make # this configurable in the future. "request_pty": False, - # By default disable known_hosts checking (since most VMs expected - # to be dynamically created). + # By default disable known_hosts checking (since most VMs expected to be dynamically created). "known_hosts": None, } diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index 3d80055197..77b481e48e 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -46,8 +46,7 @@ def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Status params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for Host/VM being deployed; otherwise, wait for - successful deprovisioning. + If True, wait for Host/VM being deployed; otherwise, wait for successful deprovisioning. Returns ------- diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 27232b54cd..50b24cc4b8 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -46,8 +46,7 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for Network being deployed; otherwise, wait for - successful deprovisioning. + If True, wait for Network being deployed; otherwise, wait for successful deprovisioning. Returns ------- diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index eadbb91fd9..47581f0725 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -162,8 +162,7 @@ def results_df(self) -> pandas.DataFrame: ------- results : pandas.DataFrame A DataFrame with configurations and results from all trials of the experiment. - Has columns - [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status] + Has columns [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status] followed by tunable config parameters (prefixed with "config.") and trial results (prefixed with "result."). The latter can be NULLs if the trial was not successful. 
diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 3024846c72..e6322c7ade 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -218,8 +218,7 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D .select_from(table) .where(*[column(key) == val for (key, val) in kwargs.items()]) ) - # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to - # avoid naming conflicts. + # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts. return dict( row._tuple() for row in cur_result.fetchall() ) # pylint: disable=protected-access diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index 48f9303c59..f299bcff68 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -74,8 +74,7 @@ def objectives(self) -> Dict[str, Literal["min", "max"]]: for objective in objectives_db_data.fetchall() } - # TODO: provide a way to get individual data to avoid repeated bulk fetches - # where only small amounts of data is accessed. + # TODO: provide a way to get individual data to avoid repeated bulk fetches where only small amounts of data is accessed. # Or else make the TrialData object lazily populate. @property diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 71a6741106..2369b0c27a 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -125,10 +125,8 @@ def test_load_composite_env_config_examples( assert child_group is composite_group checked_child_env_groups.add(child_group.name) - # Check that when we change a child env, it's value is reflected in - # the composite env as well. - # That is to say, they refer to the same objects, despite having - # potentially been loaded from separate configs. + # Check that when we change a child env, it's value is reflected in the composite env as well. + # That is to say, they refer to the same objects, despite having potentially been loaded from separate configs. if child_tunable.is_categorical: old_cat_value = child_tunable.category assert child_tunable.value == old_cat_value diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 5d97ca01c5..32ea0b9713 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -35,9 +35,8 @@ def test_cli_configs_against_schema(test_case_name: str) -> None: check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.CLI) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat - # dicts with minor constraints on them, so adding/removing params - # doesn't invalidate it against all of the config types. 
+ # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, + # so adding/removing params doesn't invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @@ -51,7 +50,6 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat - # dicts with minor constraints on them, so adding/removing params - # doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, + # so adding/removing params doesn't invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index 5ce1c0e727..1528d8d164 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -33,9 +33,7 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_ENV_CLASSES = { - # ScriptEnv is ABCMeta abstract, but there's no good way to test that - # dynamically in Python. - ScriptEnv + ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. } expected_environment_class_names = [ subclass.__module__ + "." + subclass.__name__ @@ -82,8 +80,7 @@ def test_environment_configs_against_schema(test_case_name: str) -> None: @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_environment_configs_with_extra_param(test_case_name: str) -> None: """ - Checks that the environment config fails to validate if extra params are - present in certain places. + Checks that the environment config fails to validate if extra params are present in certain places. """ check_test_case_config_with_extra_param( TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 2c485b7e30..508787a84b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -34,7 +34,6 @@ def test_globals_configs_against_schema(test_case_name: str) -> None: check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat - # dicts with minor constraints on them, so adding/removing params - # doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, + # so adding/removing params doesn't invalidate it against all of the config types. 
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index 9e5d9d72d1..ef5c0edfa3 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -41,8 +41,7 @@ ] assert expected_mlos_bench_optimizer_class_names -# Also make sure that we check for configs where the optimizer_type or -# space_adapter_type are left unspecified (None). +# Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None). expected_mlos_core_optimizer_types = list(OptimizerType) + [None] assert expected_mlos_core_optimizer_types diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 4d41600d34..032b4c0aad 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -39,8 +39,7 @@ NON_CONFIG_SERVICE_CLASSES = { ConfigPersistenceService, # configured thru the launcher cli args - # ABCMeta abstract class, but no good way to test that dynamically in Python. - TempDirContextService, + TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. AzureDeploymentService, # ABCMeta abstract base class SshService, # ABCMeta abstract base class } diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index 77f1d776ea..33124134e9 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -34,7 +34,6 @@ def test_tunable_values_configs_against_schema(test_case_name: str) -> None: check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat - # dicts with minor constraints on them, so adding/removing params - # doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, + # so adding/removing params doesn't invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index d72036fbf5..2b51ae1f0e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -43,8 +43,7 @@ def test_local_env(tunable_groups: TunableGroups) -> None: def test_local_env_service_context(tunable_groups: TunableGroups) -> None: """ - Basic check that context support for Service mixins are handled when - environment contexts are entered. 
+ Basic check that context support for Service mixins are handled when environment contexts are entered. """ local_env = create_local_env(tunable_groups, {"run": ["echo NA"]}) # pylint: disable=protected-access diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index 80b6cd148b..add2945d74 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -192,8 +192,7 @@ def test_grid_search( assert not list(grid_search_opt.suggested_configs) assert not grid_search_opt.not_converged() - # But if we still have iterations left, we should be able to suggest again - # by refilling the grid. + # But if we still have iterations left, we should be able to suggest again by refilling the grid. assert grid_search_opt.current_iteration < grid_search_opt.max_iterations assert grid_search_opt.suggest() assert list(grid_search_opt.pending_configs) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index b25b7c0534..16c88dc791 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -39,8 +39,7 @@ def get_port(self, uncached: bool = False) -> int: """ Gets the port that the SSH test server is listening on. - Note: this value can change when the service restarts so we can't rely - on the DockerServices. + Note: this value can change when the service restarts so we can't rely on the DockerServices. """ if self._port is None or uncached: port_cmd = run( diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 8f81282553..9d3bcabcb2 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -22,15 +22,12 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): def surrogate_predict( self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None ) -> npt.NDArray: - """ - Obtain a prediction from this Bayesian optimizer's surrogate model for - the given configuration(s). + """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. context : pd.DataFrame Not Yet Implemented. @@ -41,15 +38,12 @@ def surrogate_predict( def acquisition_function( self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None ) -> npt.NDArray: - """ - Invokes the acquisition function from this Bayesian optimizer for the - given configuration. + """Invokes the acquisition function from this Bayesian optimizer for the given configuration. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. context : pd.DataFrame Not Yet Implemented. 
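The keyword-only signatures documented above imply call sites along the following lines. This is a rough sketch rather than code from the patch: `opt` is assumed to be some concrete, already-fitted BaseBayesianOptimizer, and the parameter names are invented for illustration.

import pandas as pd

# `opt` is assumed to exist and to have been trained via register() beforehand.
configs = pd.DataFrame(
    [
        {"cache_size_mb": 256, "worker_threads": 4},   # hypothetical parameter names
        {"cache_size_mb": 1024, "worker_threads": 8},
    ]
)
predictions = opt.surrogate_predict(configs=configs)        # expect one prediction per config row
acquisition_values = opt.acquisition_function(configs=configs)
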
diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 362b9b6ce6..5784a42f12 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -64,21 +64,18 @@ def __init__( seed : Optional[int] By default SMAC uses a known seed (0) to keep results reproducible. - However, if a `None` seed is explicitly provided, we let a random - seed be produced by SMAC. + However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC. run_name : Optional[str] Name of this run. This is used to easily distinguish across different runs. If set to `None` (default), SMAC will generate a hash from metadata. output_directory : Optional[str] - The directory where SMAC output will saved. If set to `None` - (default), a temporary dir will be used. + The directory where SMAC output will saved. If set to `None` (default), a temporary dir will be used. max_trials : int Maximum number of trials (i.e., function evaluations) to be run. Defaults to 100. - Note that modifying this value directly affects the value of - `n_random_init`, if latter is set to `None`. + Note that modifying this value directly affects the value of `n_random_init`, if latter is set to `None`. n_random_init : Optional[int] Number of points evaluated at start to bootstrap the optimizer. @@ -122,8 +119,7 @@ def __init__( self.trial_info_map: Dict[ConfigSpace.Configuration, TrialInfo] = {} # The default when not specified is to use a known seed (0) to keep results reproducible. - # However, if a `None` seed is explicitly provided, we let a random - # seed be produced by SMAC. + # However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC. # https://automl.github.io/SMAC3/main/api/smac.scenario.html#smac.scenario.Scenario seed = -1 if seed is None else seed @@ -228,11 +224,9 @@ def __del__(self) -> None: @property def n_random_init(self) -> int: """ - Gets the number of random samples to use to initialize the optimizer's - search space sampling. + Gets the number of random samples to use to initialize the optimizer's search space sampling. - Note: This may not be equal to the value passed to the initializer, due - to logic present in the SMAC. + Note: This may not be equal to the value passed to the initializer, due to logic present in the SMAC. See Also: max_ratio Returns @@ -257,10 +251,8 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None seed : int Random seed to use for the target function. Not actually used. """ - # NOTE: Providing a target function when using the ask-and-tell - # interface is an imperfection of the API -- this is planned to be - # fixed in some future release: - # https://github.com/automl/SMAC3/issues/946 + # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API + # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946 raise RuntimeError("This function should never be called.") def _register( @@ -276,8 +268,7 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. scores : pd.DataFrame Scores from running the configs. 
The index is the same as the index of the configs. @@ -301,8 +292,7 @@ def _register( for config, (_i, score) in zip( self._to_configspace_configs(configs=configs), scores.iterrows() ): - # Retrieve previously generated TrialInfo (returned by .ask()) or - # create new TrialInfo instance + # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance info: TrialInfo = self.trial_info_map.get( config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) ) @@ -414,8 +404,7 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. Returns ------- diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 164ce40bb3..2df19b8eb2 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -30,8 +30,7 @@ class FlamlOptimizer(BaseOptimizer): Wrapper class for FLAML Optimizer: A fast library for AutoML and tuning. """ - # The name of an internal objective attribute that is calculated as a - # weighted average of the user provided objective metrics. + # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. _METRIC_NAME = "FLAML_score" def __init__( @@ -63,12 +62,10 @@ def __init__( low_cost_partial_config : dict A dictionary from a subset of controlled dimensions to the initial low-cost values. - More info: - https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune + More info: https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune seed : Optional[int] - If provided, calls np.random.seed() with the provided value to set - the seed globally at init. + If provided, calls np.random.seed() with the provided value to set the seed globally at init. """ super().__init__( parameter_space=parameter_space, @@ -109,8 +106,7 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -174,23 +170,19 @@ def register_pending( def _target_function(self, config: dict) -> Union[dict, None]: """Configuration evaluation function called by FLAML optimizer. - FLAML may suggest the same configuration multiple times (due to its - warm-start mechanism). - Once FLAML suggests an unseen configuration, we store it, and stop the - optimization process. + FLAML may suggest the same configuration multiple times (due to its warm-start mechanism). + Once FLAML suggests an unseen configuration, we store it, and stop the optimization process. Parameters ---------- config: dict Next configuration to be evaluated, as suggested by FLAML. - This config is stored internally and is returned to user, via - `.suggest()` method. + This config is stored internally and is returned to user, via `.suggest()` method. Returns ------- result: Union[dict, None] - Dictionary with a single key, `FLAML_score`, if config already - evaluated; `None` otherwise. 
+ Dictionary with a single key, `FLAML_score`, if config already evaluated; `None` otherwise. """ cs_config = normalize_config(self.optimizer_parameter_space, config) if cs_config in self.evaluated_samples: @@ -200,16 +192,12 @@ def _target_function(self, config: dict) -> Union[dict, None]: return None # Returning None stops the process def _get_next_config(self) -> dict: - """ - Warm-starts a new instance of FLAML, and returns a recommended, unseen - new configuration. - - Since FLAML does not provide an ask-and-tell interface, we need to - create a new instance of FLAML each time we get asked for a new - suggestion. This is suboptimal performance-wise, but works. To do so, - we use any previously evaluated configs to bootstrap FLAML (i.e., - warm-start). For more info: - https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start + """Warm-starts a new instance of FLAML, and returns a recommended, unseen new configuration. + + Since FLAML does not provide an ask-and-tell interface, we need to create a new instance of FLAML + each time we get asked for a new suggestion. This is suboptimal performance-wise, but works. + To do so, we use any previously evaluated configs to bootstrap FLAML (i.e., warm-start). + For more info: https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start Returns ------- diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 0f600c76bd..f96bce7075 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -80,15 +80,12 @@ def register( context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None, ) -> None: - """ - Wrapper method, which employs the space adapter (if any), before - registering the configs and scores. + """Wrapper method, which employs the space adapter (if any), before registering the configs and scores. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -136,8 +133,7 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -150,8 +146,7 @@ def suggest( self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ - Wrapper method, which employs the space adapter (if any), after - suggesting a new configuration. + Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. Parameters ---------- @@ -214,15 +209,13 @@ def register_pending( metadata: Optional[pd.DataFrame] = None, ) -> None: """Registers the given configs as "pending". - That is it say, it has been suggested by the optimizer, and an - experiment trial has been started. + That is it say, it has been suggested by the optimizer, and an experiment trial has been started. This can be useful for executing multiple trials in parallel, retry logic, etc. 
Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. context : pd.DataFrame Not Yet Implemented. metadata : Optional[pd.DataFrame] diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index b5e1de93fc..bf6f85ff88 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -39,8 +39,7 @@ def _register( Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names - and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index 5f45414e1a..58d07763f6 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -50,16 +50,13 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: - """ - Translates a configuration, which belongs to the target parameter - space, to the original parameter space. + """Translates a configuration, which belongs to the target parameter space, to the original parameter space. This method is called by the `suggest` method of the `BaseOptimizer` class. Parameters ---------- configuration : pd.DataFrame - Pandas dataframe with a single row. Column names are the parameter - names of the target parameter space. + Pandas dataframe with a single row. Column names are the parameter names of the target parameter space. Returns ------- @@ -71,25 +68,20 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: - """ - Translates a configuration, which belongs to the original parameter - space, to the target parameter space. - This method is called by the `register` method of the `BaseOptimizer` - class, and performs the inverse operation + """Translates a configuration, which belongs to the original parameter space, to the target parameter space. + This method is called by the `register` method of the `BaseOptimizer` class, and performs the inverse operation of `BaseSpaceAdapter.transform` method. Parameters ---------- configurations : pd.DataFrame Dataframe of configurations / parameters, which belong to the original parameter space. - The columns are the parameter names the original parameter space - and the rows are the configurations. + The columns are the parameter names the original parameter space and the rows are the configurations. Returns ------- configurations : pd.DataFrame Dataframe of the translated configurations / parameters. - The columns are the parameter names of the target parameter space - and the rows are the configurations. + The columns are the parameter names of the target parameter space and the rows are the configurations. 
""" pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 4fa459fd5d..b8abdedfeb 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -26,21 +26,13 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-a """ DEFAULT_NUM_LOW_DIMS = 16 - """ - Default number of dimensions in the low-dimensional search space, generated - by HeSBO projection - """ + """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2 - """ - Default percentage of bias for each special parameter value - """ + """Default percentage of bias for each special parameter value""" DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 - """ - Default number of (max) unique values of each parameter, when space - discretization is used - """ + """Default number of (max) unique values of each parameter, when space discretization is used""" def __init__( self, @@ -109,15 +101,11 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: ) target_config = self._suggested_configs.get(configuration, None) - # NOTE: HeSBO is a non-linear projection method, and does not - # inherently support inverse projection. - # To (partly) support this operation, we keep track of the - # suggested low-dim point(s) along with the respective high-dim - # point; this way we can retrieve the low-dim point, from its - # high-dim counterpart. + # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection + # To (partly) support this operation, we keep track of the suggested low-dim point(s) along with the + # respective high-dim point; this way we can retrieve the low-dim point, from its high-dim counterpart. if target_config is None: - # Inherently it is not supported to register points, which were - # not suggested by the optimizer. + # Inherently it is not supported to register points, which were not suggested by the optimizer. if configuration == self.orig_parameter_space.get_default_configuration(): # Default configuration should always be registerable. pass @@ -129,8 +117,7 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: "previously by the optimizer can be registered." ) - # ...yet, we try to support that by implementing an approximate - # reverse mapping using pseudo-inverse matrix. + # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix. if getattr(self, "_pinv_matrix", None) is None: self._try_generate_approx_inverse_mapping() @@ -194,12 +181,9 @@ def _construct_low_dim_space( for idx in range(num_low_dims) ] else: - # Currently supported optimizers do not support defining a - # discretized space (like ConfigSpace does using `q` kwarg). - # Thus, to support space discretization, we define the - # low-dimensional space using integer hyperparameters. - # We also employ a scaler, which scales suggested values to [-1, 1] - # range, used by HeSBO projection. + # Currently supported optimizers do not support defining a discretized space (like ConfigSpace does using `q` kwarg). + # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters. + # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection. 
hyperparameters = [ ConfigSpace.UniformIntegerHyperparameter( name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param @@ -207,8 +191,7 @@ def _construct_low_dim_space( for idx in range(num_low_dims) ] - # Initialize quantized values scaler: from [0, - # max_unique_values_per_param] to (-1, 1) range + # Initialize quantized values scaler: from [0, max_unique_values_per_param] to (-1, 1) range q_scaler = MinMaxScaler(feature_range=(-1, 1)) ones_vector = np.ones(num_low_dims) max_value_vector = ones_vector * max_unique_values_per_param @@ -225,10 +208,8 @@ def _construct_low_dim_space( self._target_config_space = config_space def _transform(self, configuration: dict) -> dict: - """ - Projects a low-dimensional point (configuration) to the - high-dimensional original parameter space, and then biases the - resulting parameter values towards their special value(s) (if any). + """Projects a low-dimensional point (configuration) to the high-dimensional original parameter space, + and then biases the resulting parameter values towards their special value(s) (if any). Parameters ---------- @@ -285,9 +266,7 @@ def _transform(self, configuration: dict) -> dict: def _special_param_value_scaler( self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float ) -> float: - """ - Biases the special value(s) of this parameter, by shifting the - normalized `input_value` towards those. + """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. Parameters ---------- @@ -365,8 +344,7 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non elif all( isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value ): - # User specifies list of tuples; each tuple defines the - # special value and the biasing percentage + # User specifies list of tuples; each tuple defines the special value and the biasing percentage tuple_list = value else: raise ValueError( @@ -417,16 +395,11 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non self._special_param_values_dict = sanitized_dict def _try_generate_approx_inverse_mapping(self) -> None: - """ - Tries to generate an approximate reverse mapping: i.e., from - high-dimensional space to the low-dimensional one. - Reverse mapping is generated using the pseudo-inverse matrix, of - original HeSBO projection matrix. - This mapping can be potentially used to register configurations that - were *not* previously suggested by the optimizer. - - NOTE: This method is experimental, and there is currently no guarantee - that it works as expected. + """Tries to generate an approximate reverse mapping: i.e., from high-dimensional space to the low-dimensional one. + Reverse mapping is generated using the pseudo-inverse matrix, of original HeSBO projection matrix. + This mapping can be potentially used to register configurations that were *not* previously suggested by the optimizer. + + NOTE: This method is experimental, and there is currently no guarantee that it works as expected. 
Raises ------ diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index 0cda96e72b..d5d00d0692 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -85,10 +85,8 @@ def test_basic_interface_toy_problem( if kwargs is None: kwargs = {} if optimizer_class == OptimizerType.SMAC.value: - # SMAC sets the initial random samples as a percentage of the max - # iterations, which defaults to 100. - # To avoid having to train more than 25 model iterations, we set a - # lower number of max iterations. + # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100. + # To avoid having to train more than 25 model iterations, we set a lower number of max iterations. kwargs["max_trials"] = max_iterations * 2 def objective(x: pd.Series) -> pd.DataFrame: @@ -138,8 +136,7 @@ def objective(x: pd.Series) -> pd.DataFrame: assert all_configs.shape == (20, 3) assert all_scores.shape == (20, 1) - # It would be better to put this into bayesian_optimizer_test but then we'd - # have to refit the model + # It would be better to put this into bayesian_optimizer_test but then we'd have to refit the model if isinstance(optimizer, BaseBayesianOptimizer): pred_best = optimizer.surrogate_predict(configs=best_config) assert pred_best.shape == (1,) @@ -325,8 +322,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: (best_config, best_score, _context) = best_observation (llamatune_best_config, llamatune_best_score, _context) = llamatune_best_observation - # LlamaTune's optimizer score should better (i.e., lower) than plain - # optimizer's one, or close to that + # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that assert ( best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] @@ -382,8 +378,7 @@ def test_mixed_numerics_type_input_space_types( optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] ) -> None: """ - Toy problem to test the optimizers with mixed numeric types to ensure that - original dtypes are retained. + Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. """ max_iterations = 10 if kwargs is None: diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index 44e23a02a6..cd1b250ab7 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -562,8 +562,7 @@ def test_deterministic_behavior_for_same_seed( num_target_space_dims: int, param_space_kwargs: dict ) -> None: """ - Tests LlamaTune's space adapter deterministic behavior when given same seed - in the input parameter space. + Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. 
""" def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: diff --git a/mlos_core/setup.py b/mlos_core/setup.py index e33559032f..4d895db315 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -90,9 +90,7 @@ def _get_long_desc_from_readme(base_url: str) -> dict: version=VERSION, install_requires=[ "scikit-learn>=1.2", - # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, - # which isn't currently released - "joblib>=1.1.1", + "joblib>=1.1.1", # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released "scipy>=1.3.2", "numpy>=1.24", "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710 diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 357e73f1f3..d2fc4edad7 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -167,8 +167,7 @@ def compute_zscore_for_group_agg( agg: Union[Literal["mean"], Literal["var"], Literal["std"]], ) -> None: results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? - # Compute the zscore of the chosen aggregate performance of each group - # into each row in the dataframe. + # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe. stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean() stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std() stats_df[result_col + f".{agg}_zscore"] = ( @@ -227,14 +226,12 @@ def limit_top_n_configs( results_df : Optional[pandas.DataFrame] The results dataframe to augment, by default None to use the results_df property. objectives : Iterable[str], optional - Which result column(s) to use for sorting the configs, and in which - direction ("min" or "max"). + Which result column(s) to use for sorting the configs, and in which direction ("min" or "max"). By default None to automatically select the experiment objectives. top_n_configs : int, optional How many configs to return, including the default, by default 20. method: Literal["mean", "median", "p50", "p75", "p90", "p95", "p99"] = "mean", - Which statistical method to use when sorting the config groups before - determining the cutoff, by default "mean". + Which statistical method to use when sorting the config groups before determining the cutoff, by default "mean". Returns ------- diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index b4fb789c00..8f426810f8 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -36,8 +36,7 @@ def expand_results_data_args( Returns ------- Tuple[pandas.DataFrame, Dict[str, bool]] - The results dataframe and the objectives columns in the dataframe, plus - whether or not they are in ascending order. + The results dataframe and the objectives columns in the dataframe, plus whether or not they are in ascending order. """ # Prepare the orderby columns. if results_df is None: From 93943c7b3571a01e362e4c0f9c35029fd7abf446 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 16:50:35 +0000 Subject: [PATCH 16/54] Revert "reformat with black at line length 99" This reverts commit 9c3267cc3e4e323216beb75dfb9a1f139f4eff9d. 
--- .../fio/scripts/local/process_fio_results.py | 24 +- .../scripts/local/generate_redis_config.py | 12 +- .../scripts/local/process_redis_results.py | 19 +- .../boot/scripts/local/create_new_grub_cfg.py | 10 +- .../scripts/local/generate_grub_config.py | 10 +- .../local/generate_kernel_config_script.py | 5 +- .../mlos_bench/config/schemas/__init__.py | 4 +- .../config/schemas/config_schemas.py | 21 +- mlos_bench/mlos_bench/dict_templater.py | 14 +- .../mlos_bench/environments/__init__.py | 15 +- .../environments/base_environment.py | 101 ++--- .../mlos_bench/environments/composite_env.py | 52 +-- .../mlos_bench/environments/local/__init__.py | 4 +- .../environments/local/local_env.py | 107 ++--- .../environments/local/local_fileshare_env.py | 67 ++- .../mlos_bench/environments/mock_env.py | 39 +- .../environments/remote/__init__.py | 12 +- .../environments/remote/host_env.py | 29 +- .../environments/remote/network_env.py | 33 +- .../mlos_bench/environments/remote/os_env.py | 36 +- .../environments/remote/remote_env.py | 38 +- .../environments/remote/saas_env.py | 42 +- .../mlos_bench/environments/script_env.py | 40 +- mlos_bench/mlos_bench/event_loop_context.py | 2 +- mlos_bench/mlos_bench/launcher.py | 251 ++++------- mlos_bench/mlos_bench/optimizers/__init__.py | 8 +- .../mlos_bench/optimizers/base_optimizer.py | 81 ++-- .../optimizers/convert_configspace.py | 114 +++-- .../optimizers/grid_search_optimizer.py | 76 ++-- .../optimizers/mlos_core_optimizer.py | 93 ++--- .../mlos_bench/optimizers/mock_optimizer.py | 26 +- .../optimizers/one_shot_optimizer.py | 12 +- .../optimizers/track_best_optimizer.py | 26 +- mlos_bench/mlos_bench/os_environ.py | 11 +- mlos_bench/mlos_bench/run.py | 5 +- mlos_bench/mlos_bench/schedulers/__init__.py | 4 +- .../mlos_bench/schedulers/base_scheduler.py | 92 ++-- .../mlos_bench/schedulers/sync_scheduler.py | 4 +- mlos_bench/mlos_bench/services/__init__.py | 6 +- .../mlos_bench/services/base_fileshare.py | 43 +- .../mlos_bench/services/base_service.py | 53 +-- .../mlos_bench/services/config_persistence.py | 279 +++++-------- .../mlos_bench/services/local/__init__.py | 2 +- .../mlos_bench/services/local/local_exec.py | 45 +- .../services/local/temp_dir_context.py | 19 +- .../services/remote/azure/__init__.py | 10 +- .../services/remote/azure/azure_auth.py | 40 +- .../remote/azure/azure_deployment_services.py | 135 +++--- .../services/remote/azure/azure_fileshare.py | 31 +- .../remote/azure/azure_network_services.py | 72 ++-- .../services/remote/azure/azure_saas.py | 111 +++-- .../remote/azure/azure_vm_services.py | 252 +++++------ .../services/remote/ssh/ssh_fileshare.py | 41 +- .../services/remote/ssh/ssh_host_service.py | 89 ++-- .../services/remote/ssh/ssh_service.py | 126 +++--- .../mlos_bench/services/types/__init__.py | 16 +- .../services/types/config_loader_type.py | 43 +- .../services/types/fileshare_type.py | 8 +- .../services/types/local_exec_type.py | 13 +- .../types/network_provisioner_type.py | 4 +- .../services/types/remote_config_type.py | 3 +- .../services/types/remote_exec_type.py | 5 +- mlos_bench/mlos_bench/storage/__init__.py | 4 +- .../storage/base_experiment_data.py | 19 +- mlos_bench/mlos_bench/storage/base_storage.py | 115 ++--- .../mlos_bench/storage/base_trial_data.py | 17 +- .../storage/base_tunable_config_data.py | 3 +- .../base_tunable_config_trial_group_data.py | 16 +- mlos_bench/mlos_bench/storage/sql/__init__.py | 2 +- mlos_bench/mlos_bench/storage/sql/common.py | 225 ++++------ .../mlos_bench/storage/sql/experiment.py | 250 
+++++------ .../mlos_bench/storage/sql/experiment_data.py | 101 ++--- mlos_bench/mlos_bench/storage/sql/schema.py | 48 +-- mlos_bench/mlos_bench/storage/sql/storage.py | 26 +- mlos_bench/mlos_bench/storage/sql/trial.py | 114 +++-- .../mlos_bench/storage/sql/trial_data.py | 74 ++-- .../storage/sql/tunable_config_data.py | 14 +- .../sql/tunable_config_trial_group_data.py | 41 +- .../mlos_bench/storage/storage_factory.py | 8 +- mlos_bench/mlos_bench/storage/util.py | 18 +- mlos_bench/mlos_bench/tests/__init__.py | 34 +- .../mlos_bench/tests/config/__init__.py | 8 +- .../cli/test_load_cli_config_examples.py | 57 +-- .../mlos_bench/tests/config/conftest.py | 14 +- .../test_load_environment_config_examples.py | 58 +-- .../test_load_global_config_examples.py | 8 +- .../test_load_optimizer_config_examples.py | 8 +- .../tests/config/schemas/__init__.py | 55 +-- .../config/schemas/cli/test_cli_schemas.py | 5 +- .../environments/test_environment_schemas.py | 30 +- .../schemas/globals/test_globals_schemas.py | 1 - .../optimizers/test_optimizer_schemas.py | 58 +-- .../schedulers/test_scheduler_schemas.py | 25 +- .../schemas/services/test_services_schemas.py | 31 +- .../schemas/storage/test_storage_schemas.py | 35 +- .../test_tunable_params_schemas.py | 1 - .../test_tunable_values_schemas.py | 1 - .../test_load_service_config_examples.py | 14 +- .../test_load_storage_config_examples.py | 8 +- mlos_bench/mlos_bench/tests/conftest.py | 16 +- .../mlos_bench/tests/environments/__init__.py | 14 +- .../tests/environments/base_env_test.py | 10 +- .../composite_env_service_test.py | 22 +- .../tests/environments/composite_env_test.py | 143 ++++--- .../environments/include_tunables_test.py | 40 +- .../tests/environments/local/__init__.py | 20 +- .../local/composite_local_env_test.py | 19 +- .../local/local_env_stdout_test.py | 88 ++-- .../local/local_env_telemetry_test.py | 145 +++---- .../environments/local/local_env_test.py | 73 ++-- .../environments/local/local_env_vars_test.py | 57 ++- .../local/local_fileshare_env_test.py | 25 +- .../tests/environments/mock_env_test.py | 64 ++- .../tests/environments/remote/test_ssh_env.py | 18 +- .../tests/event_loop_context_test.py | 57 +-- .../tests/launcher_in_process_test.py | 40 +- .../tests/launcher_parse_args_test.py | 123 +++--- .../mlos_bench/tests/launcher_run_test.py | 93 ++--- .../mlos_bench/tests/optimizers/conftest.py | 48 ++- .../optimizers/grid_search_optimizer_test.py | 105 ++--- .../tests/optimizers/llamatune_opt_test.py | 5 +- .../tests/optimizers/mlos_core_opt_df_test.py | 68 +-- .../optimizers/mlos_core_opt_smac_test.py | 78 ++-- .../tests/optimizers/mock_opt_test.py | 67 ++- .../optimizers/opt_bulk_register_test.py | 101 ++--- .../optimizers/toy_optimization_loop_test.py | 16 +- .../mlos_bench/tests/services/__init__.py | 8 +- .../tests/services/config_persistence_test.py | 29 +- .../tests/services/local/__init__.py | 2 +- .../services/local/local_exec_python_test.py | 9 +- .../tests/services/local/local_exec_test.py | 120 +++--- .../tests/services/local/mock/__init__.py | 2 +- .../local/mock/mock_local_exec_service.py | 23 +- .../mlos_bench/tests/services/mock_service.py | 23 +- .../tests/services/remote/__init__.py | 6 +- .../remote/azure/azure_fileshare_test.py | 142 +++---- .../azure/azure_network_services_test.py | 83 ++-- .../remote/azure/azure_vm_services_test.py | 202 ++++----- .../tests/services/remote/azure/conftest.py | 95 ++--- .../services/remote/mock/mock_auth_service.py | 26 +- .../remote/mock/mock_fileshare_service.py | 25 +- 
.../remote/mock/mock_network_service.py | 35 +- .../remote/mock/mock_remote_exec_service.py | 26 +- .../services/remote/mock/mock_vm_service.py | 55 +-- .../tests/services/remote/ssh/__init__.py | 14 +- .../tests/services/remote/ssh/fixtures.py | 63 +-- .../services/remote/ssh/test_ssh_fileshare.py | 43 +- .../remote/ssh/test_ssh_host_service.py | 94 ++--- .../services/remote/ssh/test_ssh_service.py | 53 +-- .../mlos_bench/tests/storage/conftest.py | 4 +- .../mlos_bench/tests/storage/exp_data_test.py | 67 ++- .../mlos_bench/tests/storage/exp_load_test.py | 62 ++- .../mlos_bench/tests/storage/sql/fixtures.py | 85 ++-- .../tests/storage/trial_config_test.py | 10 +- .../tests/storage/trial_schedule_test.py | 22 +- .../tests/storage/trial_telemetry_test.py | 41 +- .../tests/storage/tunable_config_data_test.py | 21 +- .../tunable_config_trial_group_data_test.py | 38 +- .../mlos_bench/tests/test_with_alt_tz.py | 6 +- .../tests/tunable_groups_fixtures.py | 38 +- .../mlos_bench/tests/tunables/conftest.py | 47 +-- .../tunables/test_tunable_categoricals.py | 2 +- .../tunables/test_tunables_size_props.py | 27 +- .../tests/tunables/tunable_comparison_test.py | 15 +- .../tests/tunables/tunable_definition_test.py | 98 ++--- .../tunables/tunable_distributions_test.py | 68 +-- .../tunables/tunable_group_indexing_test.py | 4 +- .../tunables/tunable_group_subgroup_test.py | 2 +- .../tunable_to_configspace_distr_test.py | 54 ++- .../tunables/tunable_to_configspace_test.py | 59 +-- .../tests/tunables/tunables_assign_test.py | 26 +- .../tests/tunables/tunables_str_test.py | 76 ++-- mlos_bench/mlos_bench/tunables/__init__.py | 6 +- .../mlos_bench/tunables/covariant_group.py | 18 +- mlos_bench/mlos_bench/tunables/tunable.py | 57 +-- .../mlos_bench/tunables/tunable_groups.py | 58 +-- mlos_bench/mlos_bench/util.py | 42 +- mlos_bench/mlos_bench/version.py | 2 +- mlos_bench/setup.py | 81 ++-- mlos_core/mlos_core/optimizers/__init__.py | 32 +- .../bayesian_optimizers/__init__.py | 4 +- .../bayesian_optimizers/bayesian_optimizer.py | 14 +- .../bayesian_optimizers/smac_optimizer.py | 134 +++--- .../mlos_core/optimizers/flaml_optimizer.py | 57 +-- mlos_core/mlos_core/optimizers/optimizer.py | 117 ++---- .../mlos_core/optimizers/random_optimizer.py | 30 +- .../mlos_core/spaces/adapters/__init__.py | 19 +- .../mlos_core/spaces/adapters/adapter.py | 6 +- .../mlos_core/spaces/adapters/llamatune.py | 164 +++----- .../mlos_core/spaces/converters/flaml.py | 18 +- mlos_core/mlos_core/tests/__init__.py | 19 +- .../optimizers/bayesian_optimizers_test.py | 23 +- .../mlos_core/tests/optimizers/conftest.py | 6 +- .../tests/optimizers/one_hot_test.py | 77 ++-- .../optimizers/optimizer_multiobj_test.py | 78 ++-- .../tests/optimizers/optimizer_test.py | 213 ++++------ .../spaces/adapters/identity_adapter_test.py | 25 +- .../tests/spaces/adapters/llamatune_test.py | 393 +++++++----------- .../adapters/space_adapter_factory_test.py | 56 +-- .../mlos_core/tests/spaces/spaces_test.py | 39 +- mlos_core/mlos_core/util.py | 10 +- mlos_core/mlos_core/version.py | 2 +- mlos_core/setup.py | 57 ++- mlos_viz/mlos_viz/__init__.py | 19 +- mlos_viz/mlos_viz/base.py | 208 ++++----- mlos_viz/mlos_viz/dabl.py | 62 +-- mlos_viz/mlos_viz/tests/test_mlos_viz.py | 4 +- mlos_viz/mlos_viz/util.py | 7 +- mlos_viz/mlos_viz/version.py | 2 +- mlos_viz/setup.py | 47 +-- 210 files changed, 4082 insertions(+), 6260 deletions(-) diff --git a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py 
b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py index 2c6da8cc6a..c32dea9bf6 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py @@ -20,7 +20,7 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: Flatten every dict in the hierarchy and rename the keys with the dict path. """ if isinstance(data, dict): - for key, val in data.items(): + for (key, val) in data.items(): yield from _flat_dict(val, f"{path}.{key}") else: yield (path, data) @@ -30,15 +30,13 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: """ Convert FIO read data from JSON to tall CSV. """ - with open(input_file, mode="r", encoding="utf-8") as fh_input: + with open(input_file, mode='r', encoding='utf-8') as fh_input: json_data = json.load(fh_input) - data = list( - itertools.chain( - _flat_dict(json_data["jobs"][0], prefix), - _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"), - ) - ) + data = list(itertools.chain( + _flat_dict(json_data["jobs"][0], prefix), + _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util") + )) tall_df = pandas.DataFrame(data, columns=["metric", "value"]) tall_df.to_csv(output_file, index=False) @@ -51,12 +49,12 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.") parser.add_argument( - "input", help="FIO benchmark results in JSON format (downloaded from a remote VM)." - ) + "input", help="FIO benchmark results in JSON format (downloaded from a remote VM).") parser.add_argument( - "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench)." - ) - parser.add_argument("--prefix", default="fio", help="Prefix of the metric IDs (default 'fio')") + "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).") + parser.add_argument( + "--prefix", default="fio", + help="Prefix of the metric IDs (default 'fio')") args = parser.parse_args() _main(args.input, args.output, args.prefix) diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py index d41f20d2a9..949b9f9d91 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py @@ -14,19 +14,17 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( - fname_output, "wt", encoding="utf-8", newline="" - ) as fh_config: - for key, val in json.load(fh_tunables).items(): - line = f"{key} {val}" + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ + open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: + for (key, val) in json.load(fh_tunables).items(): + line = f'{key} {val}' fh_config.write(line + "\n") print(line) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate Redis config from tunable parameters JSON." 
- ) + description="generate Redis config from tunable parameters JSON.") parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output Redis config file.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py index 81a2b673a4..e33c717953 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py @@ -21,19 +21,18 @@ def _main(input_file: str, output_file: str) -> None: # Format the results from wide to long # The target is columns of metric and value to act as key-value pairs. df_long = ( - df_wide.melt(id_vars=["test"]) + df_wide + .melt(id_vars=["test"]) .assign(metric=lambda df: df["test"] + "_" + df["variable"]) .drop(columns=["test", "variable"]) .loc[:, ["metric", "value"]] ) # Add a default `score` metric to the end of the dataframe. - df_long = pd.concat( - [ - df_long, - pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}), - ] - ) + df_long = pd.concat([ + df_long, + pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}) + ]) df_long.to_csv(output_file, index=False) print(f"Converted: {input_file} -> {output_file}") @@ -43,9 +42,7 @@ def _main(input_file: str, output_file: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.") parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).") - parser.add_argument( - "output", - help="Converted Redis benchmark data" + " (to be consumed by OS Autotune framework).", - ) + parser.add_argument("output", help="Converted Redis benchmark data" + + " (to be consumed by OS Autotune framework).") args = parser.parse_args() _main(args.input, args.output) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py index 40a05e1511..41bd162459 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py @@ -14,10 +14,8 @@ JSON_CONFIG_FILE = "config-boot-time.json" NEW_CFG = "zz-mlos-boot-params.cfg" -with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open( - NEW_CFG, "w", encoding="UTF-8" -) as fh_config: +with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \ + open(NEW_CFG, 'w', encoding='UTF-8') as fh_config: for key, val in json.load(fh_json).items(): - fh_config.write( - 'GRUB_CMDLINE_LINUX_DEFAULT="$' f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n' - ) + fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$' + f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n') diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py index 9f130e5c0e..de344d61fb 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py @@ -14,10 +14,9 @@ def 
_main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( - fname_output, "wt", encoding="utf-8", newline="" - ) as fh_config: - for key, val in json.load(fh_tunables).items(): + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ + open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: + for (key, val) in json.load(fh_tunables).items(): line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"' fh_config.write(line + "\n") print(line) @@ -25,8 +24,7 @@ def _main(fname_input: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Generate GRUB config from tunable parameters JSON." - ) + description="Generate GRUB config from tunable parameters JSON.") parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output shell script to configure GRUB.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py index e632495061..85a49a1817 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py @@ -22,7 +22,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: tunables_meta = json.load(fh_meta) with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for key, val in tunables_data.items(): + for (key, val) in tunables_data.items(): meta = tunables_meta.get(key, {}) name_prefix = meta.get("name_prefix", "") line = f'echo "{val}" > {name_prefix}{key}' @@ -33,8 +33,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate a script to update kernel parameters from tunables JSON." - ) + description="generate a script to update kernel parameters from tunables JSON.") parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("meta", help="JSON file with tunable parameters metadata.") diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py index 672a215aad..fa3b63e2e6 100644 --- a/mlos_bench/mlos_bench/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/config/schemas/__init__.py @@ -9,6 +9,6 @@ from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema __all__ = [ - "ConfigSchema", - "CONFIG_SCHEMA_DIR", + 'ConfigSchema', + 'CONFIG_SCHEMA_DIR', ] diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index 181f96e5d6..82cbcacce2 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -27,14 +27,9 @@ # It is used in `ConfigSchema.validate()` method below. # NOTE: this may cause pytest to fail if it's expecting exceptions # to be raised for invalid configs. 
-_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION" -_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in { - "true", - "y", - "yes", - "on", - "1", -} +_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION' +_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower() + in {'true', 'y', 'yes', 'on', '1'}) # Note: we separate out the SchemaStore from a class method on ConfigSchema @@ -85,12 +80,10 @@ def _load_registry(cls) -> None: """Also store them in a Registry object for referencing by recent versions of jsonschema.""" if not cls._SCHEMA_STORE: cls._load_schemas() - cls._REGISTRY = Registry().with_resources( - [ - (url, Resource.from_contents(schema, default_specification=DRAFT202012)) - for url, schema in cls._SCHEMA_STORE.items() - ] - ) + cls._REGISTRY = Registry().with_resources([ + (url, Resource.from_contents(schema, default_specification=DRAFT202012)) + for url, schema in cls._SCHEMA_STORE.items() + ]) @property def registry(self) -> Registry: diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index 3c14e63598..4ccef7817b 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -13,7 +13,7 @@ from mlos_bench.os_environ import environ -class DictTemplater: # pylint: disable=too-few-public-methods +class DictTemplater: # pylint: disable=too-few-public-methods """ Simple class to help with nested dictionary $var templating. """ @@ -32,9 +32,9 @@ def __init__(self, source_dict: Dict[str, Any]): # The source/target dictionary to expand. self._dict: Dict[str, Any] = {} - def expand_vars( - self, *, extra_source_dict: Optional[Dict[str, Any]] = None, use_os_env: bool = False - ) -> Dict[str, Any]: + def expand_vars(self, *, + extra_source_dict: Optional[Dict[str, Any]] = None, + use_os_env: bool = False) -> Dict[str, Any]: """ Expand the template variables in the destination dictionary. @@ -55,9 +55,7 @@ def expand_vars( assert isinstance(self._dict, dict) return self._dict - def _expand_vars( - self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool - ) -> Any: + def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any: """ Recursively expand $var strings in the currently operating dictionary. """ @@ -73,7 +71,7 @@ def _expand_vars( elif isinstance(value, dict): # Note: we use a loop instead of dict comprehension in order to # allow secondary expansion of subsequent values immediately. 
- for key, val in value.items(): + for (key, val) in value.items(): value[key] = self._expand_vars(val, extra_source_dict, use_os_env) elif isinstance(value, list): value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value] diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py index 629e7d9c5f..a1ccadae5f 100644 --- a/mlos_bench/mlos_bench/environments/__init__.py +++ b/mlos_bench/mlos_bench/environments/__init__.py @@ -15,11 +15,12 @@ from mlos_bench.environments.status import Status __all__ = [ - "Status", - "Environment", - "MockEnv", - "RemoteEnv", - "LocalEnv", - "LocalFileShareEnv", - "CompositeEnv", + 'Status', + + 'Environment', + 'MockEnv', + 'RemoteEnv', + 'LocalEnv', + 'LocalFileShareEnv', + 'CompositeEnv', ] diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index f1ec25823c..61fbd69f50 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -48,16 +48,15 @@ class Environment(metaclass=abc.ABCMeta): """ @classmethod - def new( - cls, - *, - env_name: str, - class_name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ) -> "Environment": + def new(cls, + *, + env_name: str, + class_name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -95,18 +94,16 @@ def new( config=config, global_config=global_config, tunables=tunables, - service=service, + service=service ) - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment with a given config. @@ -137,29 +134,24 @@ def __init__( self._const_args: Dict[str, TunableValue] = config.get("const_args", {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Environment: '%s' Service: %s", - name, - self._service.pprint() if self._service else None, - ) + _LOG.debug("Environment: '%s' Service: %s", name, + self._service.pprint() if self._service else None) if tunables is None: - _LOG.warning( - "No tunables provided for %s. Tunable inheritance across composite environments may be broken.", - name, - ) + _LOG.warning("No tunables provided for %s. 
Tunable inheritance across composite environments may be broken.", name) tunables = TunableGroups() groups = self._expand_groups( - config.get("tunable_params", []), (global_config or {}).get("tunable_params_map", {}) - ) + config.get("tunable_params", []), + (global_config or {}).get("tunable_params_map", {})) _LOG.debug("Tunable groups for: '%s' :: %s", name, groups) self._tunable_params = tunables.subgroup(groups) # If a parameter comes from the tunables, do not require it in the const_args or globals - req_args = set(config.get("required_args", [])) - set( - self._tunable_params.get_param_values().keys() + req_args = ( + set(config.get("required_args", [])) - + set(self._tunable_params.get_param_values().keys()) ) merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args) self._const_args = self._expand_vars(self._const_args, global_config or {}) @@ -168,7 +160,8 @@ def __init__( _LOG.debug("Parameters for '%s' :: %s", name, self._params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2)) + _LOG.debug("Config for: '%s'\n%s", + name, json.dumps(self.config, indent=2)) def _validate_json_config(self, config: dict, name: str) -> None: """ @@ -186,9 +179,8 @@ def _validate_json_config(self, config: dict, name: str) -> None: ConfigSchema.ENVIRONMENT.validate(json_config) @staticmethod - def _expand_groups( - groups: Iterable[str], groups_exp: Dict[str, Union[str, Sequence[str]]] - ) -> List[str]: + def _expand_groups(groups: Iterable[str], + groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]: """ Expand `$tunable_group` into actual names of the tunable groups. @@ -210,9 +202,7 @@ def _expand_groups( if grp[:1] == "$": tunable_group_name = grp[1:] if tunable_group_name not in groups_exp: - raise KeyError( - f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}" - ) + raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}") add_groups = groups_exp[tunable_group_name] res += [add_groups] if isinstance(add_groups, str) else add_groups else: @@ -220,9 +210,7 @@ def _expand_groups( return res @staticmethod - def _expand_vars( - params: Dict[str, TunableValue], global_config: Dict[str, TunableValue] - ) -> dict: + def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict: """ Expand `$var` into actual values of the variables. """ @@ -233,7 +221,7 @@ def _config_loader_service(self) -> "SupportsConfigLoading": assert self._service is not None return self._service.config_loader_service - def __enter__(self) -> "Environment": + def __enter__(self) -> 'Environment': """ Enter the environment's benchmarking context. """ @@ -244,12 +232,9 @@ def __enter__(self) -> "Environment": self._in_context = True return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the benchmarking environment. 
""" @@ -319,8 +304,7 @@ def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ return tunables.get_param_values( group_names=list(self._tunable_params.get_covariant_group_names()), - into_params=self._const_args.copy(), - ) + into_params=self._const_args.copy()) @property def tunable_params(self) -> TunableGroups: @@ -380,15 +364,10 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - # (Derived classes still have to check `self._tunable_params.is_updated()`). is_updated = self._tunable_params.is_updated() if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Env '%s': Tunable groups reset = %s :: %s", - self, - is_updated, - { - name: self._tunable_params.is_updated([name]) - for name in self._tunable_params.get_covariant_group_names() - }, - ) + _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, { + name: self._tunable_params.is_updated([name]) + for name in self._tunable_params.get_covariant_group_names() + }) else: _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated) diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index 36ab99a223..a71b8ab9be 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -27,15 +27,13 @@ class CompositeEnv(Environment): Composite benchmark environment. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment with a given config. @@ -55,13 +53,8 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) # By default, the Environment includes only the tunables explicitly specified # in the "tunable_params" section of the config. `CompositeEnv`, however, must @@ -77,19 +70,17 @@ def __init__( # each CompositeEnv gets a copy of the original global config and adjusts it with # the `const_args` specific to it. 
global_config = (global_config or {}).copy() - for key, val in self._const_args.items(): + for (key, val) in self._const_args.items(): global_config.setdefault(key, val) for child_config_file in config.get("include_children", []): for env in self._config_loader_service.load_environment_list( - child_config_file, tunables, global_config, self._const_args, self._service - ): + child_config_file, tunables, global_config, self._const_args, self._service): self._add_child(env, tunables) for child_config in config.get("children", []): env = self._config_loader_service.build_environment( - child_config, tunables, global_config, self._const_args, self._service - ) + child_config, tunables, global_config, self._const_args, self._service) self._add_child(env, tunables) _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params) @@ -101,12 +92,9 @@ def __enter__(self) -> Environment: self._child_contexts = [env.__enter__() for env in self._children] return super().__enter__() - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: ex_throw = None for env in reversed(self._children): try: @@ -144,11 +132,8 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: pretty : str Pretty-printed environment configuration. """ - return ( - super().pprint(indent, level) - + "\n" - + "\n".join(child.pprint(indent, level + 1) for child in self._children) - ) + return super().pprint(indent, level) + '\n' + '\n'.join( + child.pprint(indent, level + 1) for child in self._children) def _add_child(self, env: Environment, tunables: TunableGroups) -> None: """ @@ -180,8 +165,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - """ assert self._in_context self._is_ready = super().setup(tunables, global_config) and all( - env_context.setup(tunables, global_config) for env_context in self._child_contexts - ) + env_context.setup(tunables, global_config) for env_context in self._child_contexts) return self._is_ready def teardown(self) -> None: diff --git a/mlos_bench/mlos_bench/environments/local/__init__.py b/mlos_bench/mlos_bench/environments/local/__init__.py index a99eefea19..0cdd8349b4 100644 --- a/mlos_bench/mlos_bench/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/environments/local/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv __all__ = [ - "LocalEnv", - "LocalFileShareEnv", + 'LocalEnv', + 'LocalFileShareEnv', ] diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index 72616f7cd3..da20f5c961 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -36,15 +36,13 @@ class LocalEnv(ScriptEnv): Scheduler-side Environment that runs scripts locally. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for local execution. 
@@ -67,17 +65,11 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) - - assert self._service is not None and isinstance( - self._service, SupportsLocalExec - ), "LocalEnv requires a service that supports local execution" + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ + "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service self._temp_dir: Optional[str] = None @@ -91,18 +83,13 @@ def __init__( def __enter__(self) -> Environment: assert self._temp_dir is None and self._temp_dir_context is None - self._temp_dir_context = self._local_exec_service.temp_dir_context( - self.config.get("temp_dir") - ) + self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir")) self._temp_dir = self._temp_dir_context.__enter__() return super().__enter__() - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the benchmarking environment. """ @@ -150,14 +137,10 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - fname = path_join(self._temp_dir, self._dump_meta_file) _LOG.debug("Dump tunables metadata to file: %s", fname) with open(fname, "wt", encoding="utf-8") as fh_meta: - json.dump( - { - tunable.name: tunable.meta - for (tunable, _group) in self._tunable_params - if tunable.meta - }, - fh_meta, - ) + json.dump({ + tunable.name: tunable.meta + for (tunable, _group) in self._tunable_params if tunable.meta + }, fh_meta) if self._script_setup: (return_code, _output) = self._local_exec(self._script_setup, self._temp_dir) @@ -197,24 +180,18 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: _LOG.debug("Not reading the data at: %s", self) return (Status.SUCCEEDED, timestamp, stdout_data) - data = self._normalize_columns( - pandas.read_csv( - self._config_loader_service.resolve_path( - self._read_results_file, extra_paths=[self._temp_dir] - ), - index_col=False, - ) - ) + data = self._normalize_columns(pandas.read_csv( + self._config_loader_service.resolve_path( + self._read_results_file, extra_paths=[self._temp_dir]), + index_col=False, + )) _LOG.debug("Read data:\n%s", data) if list(data.columns) == ["metric", "value"]: - _LOG.info( - "Local results have (metric,value) header and %d rows: assume long format", - len(data), - ) + _LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data)) data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list()) # Try to convert string metrics to numbers. 
- data = data.apply(pandas.to_numeric, errors="coerce").fillna(data) # type: ignore[assignment] # (false positive) + data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive) elif len(data) == 1: _LOG.info("Local results have 1 row: assume wide format") else: @@ -232,8 +209,8 @@ def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame: # Windows cmd interpretation of > redirect symbols can leave trailing spaces in # the final column, which leads to misnamed columns. # For now, we simply strip trailing spaces from column names to account for that. - if sys.platform == "win32": - data.rename(str.rstrip, axis="columns", inplace=True) + if sys.platform == 'win32': + data.rename(str.rstrip, axis='columns', inplace=True) return data def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: @@ -245,23 +222,24 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: assert self._temp_dir is not None try: fname = self._config_loader_service.resolve_path( - self._read_telemetry_file, extra_paths=[self._temp_dir] - ) + self._read_telemetry_file, extra_paths=[self._temp_dir]) # TODO: Use the timestamp of the CSV file as our status timestamp? # FIXME: We should not be assuming that the only output file type is a CSV. - data = self._normalize_columns(pandas.read_csv(fname, index_col=False)) + data = self._normalize_columns( + pandas.read_csv(fname, index_col=False)) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") expected_col_names = ["timestamp", "metric", "value"] if len(data.columns) != len(expected_col_names): - raise ValueError(f"Telemetry data must have columns {expected_col_names}") + raise ValueError(f'Telemetry data must have columns {expected_col_names}') if list(data.columns) != expected_col_names: # Assume no header - this is ok for telemetry data. 
- data = pandas.read_csv(fname, index_col=False, names=expected_col_names) + data = pandas.read_csv( + fname, index_col=False, names=expected_col_names) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") except FileNotFoundError as ex: @@ -270,14 +248,10 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: _LOG.debug("Read telemetry data:\n%s", data) col_dtypes: Mapping[int, Type] = {0: datetime} - return ( - status, - timestamp, - [ - (pandas.Timestamp(ts).to_pydatetime(), metric, value) - for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) - ], - ) + return (status, timestamp, [ + (pandas.Timestamp(ts).to_pydatetime(), metric, value) + for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) + ]) def teardown(self) -> None: """ @@ -309,8 +283,7 @@ def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple env_params = self._get_env_params() _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params) (return_code, stdout, stderr) = self._local_exec_service.local_exec( - script, env=env_params, cwd=cwd - ) + script, env=env_params, cwd=cwd) if return_code != 0: _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr) return (return_code, {"stdout": stdout, "stderr": stderr}) diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index 7a6862ab2c..174afd387c 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -29,15 +29,13 @@ class LocalFileShareEnv(LocalEnv): and uploads/downloads data to the shared file storage. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new application environment with a given config. @@ -61,22 +59,14 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - assert self._service is not None and isinstance( - self._service, SupportsLocalExec - ), "LocalEnv requires a service that supports local execution" + assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ + "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service - assert self._service is not None and isinstance( - self._service, SupportsFileShareOps - ), "LocalEnv requires a service that supports file upload/download operations" + assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \ + "LocalEnv requires a service that supports file upload/download operations" self._file_share_service: SupportsFileShareOps = self._service self._upload = self._template_from_to("upload") @@ -87,12 +77,14 @@ def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]: Convert a list of {"from": "...", "to": "..."} to a list of pairs of string.Template objects so that we can plug in self._params into it later. """ - return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])] + return [ + (Template(d['from']), Template(d['to'])) + for d in self.config.get(config_key, []) + ] @staticmethod - def _expand( - from_to: Iterable[Tuple[Template, Template]], params: Mapping[str, TunableValue] - ) -> Generator[Tuple[str, str], None, None]: + def _expand(from_to: Iterable[Tuple[Template, Template]], + params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. Return a generator of (str, str) pairs of paths. 
@@ -127,14 +119,9 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for path_from, path_to in self._expand(self._upload, params): - self._file_share_service.upload( - self._params, - self._config_loader_service.resolve_path( - path_from, extra_paths=[self._temp_dir] - ), - path_to, - ) + for (path_from, path_to) in self._expand(self._upload, params): + self._file_share_service.upload(self._params, self._config_loader_service.resolve_path( + path_from, extra_paths=[self._temp_dir]), path_to) return self._is_ready def _download_files(self, ignore_missing: bool = False) -> None: @@ -150,15 +137,11 @@ def _download_files(self, ignore_missing: bool = False) -> None: assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for path_from, path_to in self._expand(self._download, params): + for (path_from, path_to) in self._expand(self._download, params): try: - self._file_share_service.download( - self._params, - path_from, - self._config_loader_service.resolve_path( - path_to, extra_paths=[self._temp_dir] - ), - ) + self._file_share_service.download(self._params, + path_from, self._config_loader_service.resolve_path( + path_to, extra_paths=[self._temp_dir])) except FileNotFoundError as ex: _LOG.warning("Cannot download: %s", path_from) if not ignore_missing: diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index 2f4d4b0ab4..cc47b95500 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -29,15 +29,13 @@ class MockEnv(Environment): _NOISE_VAR = 0.2 """Variance of the Gaussian noise added to the benchmark value.""" - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment that produces mock benchmark data. @@ -57,13 +55,8 @@ def __init__( service: Service An optional service object. Not used by this class. """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) seed = int(self.config.get("mock_env_seed", -1)) self._random = random.Random(seed or None) if seed >= 0 else None self._range = self.config.get("mock_env_range") @@ -88,9 +81,9 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: return result # Simple convex function of all tunable parameters. - score = numpy.mean( - numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params]) - ) + score = numpy.mean(numpy.square([ + self._normalized(tunable) for (tunable, _group) in self._tunable_params + ])) # Add noise and shift the benchmark value from [0, 1] to a given range. 
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0 @@ -108,11 +101,11 @@ def _normalized(tunable: Tunable) -> float: """ val = None if tunable.is_categorical: - val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1) + val = (tunable.categories.index(tunable.category) / + float(len(tunable.categories) - 1)) elif tunable.is_numerical: - val = (tunable.numerical_value - tunable.range[0]) / float( - tunable.range[1] - tunable.range[0] - ) + val = ((tunable.numerical_value - tunable.range[0]) / + float(tunable.range[1] - tunable.range[0])) else: raise ValueError("Invalid parameter type: " + tunable.type) # Explicitly clip the value in case of numerical errors. diff --git a/mlos_bench/mlos_bench/environments/remote/__init__.py b/mlos_bench/mlos_bench/environments/remote/__init__.py index be18bff2fe..f07575ac86 100644 --- a/mlos_bench/mlos_bench/environments/remote/__init__.py +++ b/mlos_bench/mlos_bench/environments/remote/__init__.py @@ -14,10 +14,10 @@ from mlos_bench.environments.remote.vm_env import VMEnv __all__ = [ - "HostEnv", - "NetworkEnv", - "OSEnv", - "RemoteEnv", - "SaaSEnv", - "VMEnv", + 'HostEnv', + 'NetworkEnv', + 'OSEnv', + 'RemoteEnv', + 'SaaSEnv', + 'VMEnv', ] diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py index 3b1abcd79a..05896c9e60 100644 --- a/mlos_bench/mlos_bench/environments/remote/host_env.py +++ b/mlos_bench/mlos_bench/environments/remote/host_env.py @@ -22,15 +22,13 @@ class HostEnv(Environment): Remote host environment. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for host operations. @@ -51,17 +49,10 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM/host, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - assert self._service is not None and isinstance( - self._service, SupportsHostProvisioning - ), "HostEnv requires a service that supports host provisioning operations" + assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \ + "HostEnv requires a service that supports host provisioning operations" self._host_service: SupportsHostProvisioning = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index afa38229f5..552f1729d9 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -27,15 +27,13 @@ class NetworkEnv(Environment): but no real tuning is expected for it ... yet. 
""" - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for network operations. @@ -56,21 +54,14 @@ def __init__( An optional service object (e.g., providing methods to deploy a network, etc.). """ - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) # Virtual networks can be used for more than one experiment, so by default # we don't attempt to deprovision them. self._deprovision_on_teardown = config.get("deprovision_on_teardown", False) - assert self._service is not None and isinstance( - self._service, SupportsNetworkProvisioning - ), "NetworkEnv requires a service that supports network provisioning" + assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \ + "NetworkEnv requires a service that supports network provisioning" self._network_service: SupportsNetworkProvisioning = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -114,9 +105,7 @@ def teardown(self) -> None: return # Else _LOG.info("Network tear down: %s", self) - (status, params) = self._network_service.deprovision_network( - self._params, ignore_errors=True - ) + (status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True) if status.is_pending(): (status, _) = self._network_service.wait_network_deployment(params, is_setup=False) diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py index 9fa2b5886a..ef733c77c2 100644 --- a/mlos_bench/mlos_bench/environments/remote/os_env.py +++ b/mlos_bench/mlos_bench/environments/remote/os_env.py @@ -24,15 +24,13 @@ class OSEnv(Environment): OS Level Environment for a host. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for remote execution. @@ -55,22 +53,14 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) - - assert self._service is not None and isinstance( - self._service, SupportsHostOps - ), "RemoteEnv requires a service that supports host operations" + super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsHostOps), \ + "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance( - self._service, SupportsOSOps - ), "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance(self._service, SupportsOSOps), \ + "RemoteEnv requires a service that supports host operations" self._os_service: SupportsOSOps = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index 683405c6c5..cf38a57b01 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -32,15 +32,13 @@ class RemoteEnv(ScriptEnv): e.g. Application Environment """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for remote execution. @@ -63,25 +61,18 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a Host, VM, OS, etc.). 
""" - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) self._wait_boot = self.config.get("wait_boot", False) - assert self._service is not None and isinstance( - self._service, SupportsRemoteExec - ), "RemoteEnv requires a service that supports remote execution operations" + assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \ + "RemoteEnv requires a service that supports remote execution operations" self._remote_exec_service: SupportsRemoteExec = self._service if self._wait_boot: - assert self._service is not None and isinstance( - self._service, SupportsHostOps - ), "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance(self._service, SupportsHostOps), \ + "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -179,8 +170,7 @@ def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optiona env_params = self._get_env_params() _LOG.debug("Submit script: %s with %s", self, env_params) (status, output) = self._remote_exec_service.remote_exec( - script, config=self._params, env_params=env_params - ) + script, config=self._params, env_params=env_params) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) if status in {Status.PENDING, Status.SUCCEEDED}: (status, output) = self._remote_exec_service.get_remote_exec_results(output) diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index 211db536d0..b661bfad7e 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -23,15 +23,13 @@ class SaaSEnv(Environment): Cloud-based (configurable) SaaS environment. """ - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for (configurable) cloud-based SaaS instance. @@ -52,22 +50,15 @@ def __init__( An optional service object (e.g., providing methods to configure the remote service). 
""" - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) - - assert self._service is not None and isinstance( - self._service, SupportsHostOps - ), "RemoteEnv requires a service that supports host operations" + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) + + assert self._service is not None and isinstance(self._service, SupportsHostOps), \ + "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance( - self._service, SupportsRemoteConfig - ), "SaaSEnv requires a service that supports remote host configuration API" + assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \ + "SaaSEnv requires a service that supports remote host configuration API" self._config_service: SupportsRemoteConfig = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -93,8 +84,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False (status, _) = self._config_service.configure( - self._params, self._tunable_params.get_param_values() - ) + self._params, self._tunable_params.get_param_values()) if not status.is_succeeded(): return False @@ -103,7 +93,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False # Azure Flex DB instances currently require a VM reboot after reconfiguration. - if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"): + if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'): _LOG.info("Restarting: %s", self) (status, params) = self._host_service.restart_host(self._params) if status.is_pending(): diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index d65d137459..129ac21a0f 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -27,15 +27,13 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta): _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]") - def __init__( - self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ): + def __init__(self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None): """ Create a new environment for script execution. @@ -65,29 +63,19 @@ def __init__( An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__( - name=name, - config=config, - global_config=global_config, - tunables=tunables, - service=service, - ) + super().__init__(name=name, config=config, global_config=global_config, + tunables=tunables, service=service) self._script_setup = self.config.get("setup") self._script_run = self.config.get("run") self._script_teardown = self.config.get("teardown") self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", []) - self._shell_env_params_rename: Dict[str, str] = self.config.get( - "shell_env_params_rename", {} - ) + self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {}) results_stdout_pattern = self.config.get("results_stdout_pattern") - self._results_stdout_pattern: Optional[re.Pattern[str]] = ( - re.compile(results_stdout_pattern, flags=re.MULTILINE) - if results_stdout_pattern - else None - ) + self._results_stdout_pattern: Optional[re.Pattern[str]] = \ + re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None def _get_env_params(self, restrict: bool = True) -> Dict[str, str]: """ @@ -128,6 +116,4 @@ def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]: if not self._results_stdout_pattern: return {} _LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout) - return { - key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout) - } + return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)} diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index f69893871f..4555ab7f50 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -20,7 +20,7 @@ else: from typing_extensions import TypeAlias -CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name +CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name if sys.version_info >= (3, 9): FutureReturnType: TypeAlias = Future[CoroReturnType] else: diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index d988e370b3..c8e48dab69 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -32,7 +32,7 @@ from mlos_bench.util import try_parse_val _LOG_LEVEL = logging.INFO -_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s" +_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s' logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT) _LOG = logging.getLogger(__name__) @@ -54,7 +54,8 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st For additional details, please see the website or the README.md files in the source tree: """ - parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog) + parser = argparse.ArgumentParser(description=f"{description} : {long_text}", + epilog=epilog) (args, args_rest) = self._parse_args(parser, argv) # Bootstrap config loader: command line takes priority. @@ -95,13 +96,13 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI. # It's useful to keep it there explicitly mostly for the --help output. 
if args.experiment_id: - self.global_config["experiment_id"] = args.experiment_id + self.global_config['experiment_id'] = args.experiment_id # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line if args.trial_config_repeat_count: self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count # Ensure that the trial_id is present since it gets used by some other # configs but is typically controlled by the run optimize loop. - self.global_config.setdefault("trial_id", 1) + self.global_config.setdefault('trial_id', 1) self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True) assert isinstance(self.global_config, dict) @@ -109,29 +110,24 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # --service cli args should override the config file values. service_files: List[str] = config.get("services", []) + (args.service or []) assert isinstance(self._parent_service, SupportsConfigLoading) - self._parent_service = self._parent_service.load_services( - service_files, self.global_config, self._parent_service - ) + self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service) env_path = args.environment or config.get("environment") if not env_path: _LOG.error("No environment config specified.") - parser.error( - "At least the Environment config must be specified." - + " Run `mlos_bench --help` and consult `README.md` for more info." - ) + parser.error("At least the Environment config must be specified." + + " Run `mlos_bench --help` and consult `README.md` for more info.") self.root_env_config = self._config_loader.resolve_path(env_path) self.environment: Environment = self._config_loader.load_environment( - self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service - ) + self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service) _LOG.info("Init environment: %s", self.environment) # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer self.tunables = self._init_tunable_values( args.random_init or config.get("random_init", False), config.get("random_seed") if args.random_seed is None else args.random_seed, - config.get("tunable_values", []) + (args.tunable_values or []), + config.get("tunable_values", []) + (args.tunable_values or []) ) _LOG.info("Init tunables: %s", self.tunables) @@ -141,11 +137,7 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st self.storage = self._load_storage(args.storage or config.get("storage")) _LOG.info("Init storage: %s", self.storage) - self.teardown: bool = ( - bool(args.teardown) - if args.teardown is not None - else bool(config.get("teardown", True)) - ) + self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True)) self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler")) _LOG.info("Init scheduler: %s", self.scheduler) @@ -164,146 +156,87 @@ def service(self) -> Service: return self._parent_service @staticmethod - def _parse_args( - parser: argparse.ArgumentParser, argv: Optional[List[str]] - ) -> Tuple[argparse.Namespace, List[str]]: + def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]: """ Parse the command line arguments. """ parser.add_argument( - "--config", - required=False, - help="Main JSON5 configuration file. 
Its keys are the same as the" - + " command line options and can be overridden by the latter.\n" - + "\n" - + " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ " - + " for additional config examples for this and other arguments.", - ) + '--config', required=False, + help='Main JSON5 configuration file. Its keys are the same as the' + + ' command line options and can be overridden by the latter.\n' + + '\n' + + ' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' + + ' for additional config examples for this and other arguments.') parser.add_argument( - "--log_file", - "--log-file", - required=False, - help="Path to the log file. Use stdout if omitted.", - ) + '--log_file', '--log-file', required=False, + help='Path to the log file. Use stdout if omitted.') parser.add_argument( - "--log_level", - "--log-level", - required=False, - type=str, - help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}." - + " Set to DEBUG for debug, WARNING for warnings only.", - ) + '--log_level', '--log-level', required=False, type=str, + help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' + + ' Set to DEBUG for debug, WARNING for warnings only.') parser.add_argument( - "--config_path", - "--config-path", - "--config-paths", - "--config_paths", - nargs="+", - action="extend", - required=False, - help="One or more locations of JSON config files.", - ) + '--config_path', '--config-path', '--config-paths', '--config_paths', + nargs="+", action='extend', required=False, + help='One or more locations of JSON config files.') parser.add_argument( - "--service", - "--services", - nargs="+", - action="extend", - required=False, - help="Path to JSON file with the configuration of the service(s) for environment(s) to use.", - ) + '--service', '--services', + nargs='+', action='extend', required=False, + help='Path to JSON file with the configuration of the service(s) for environment(s) to use.') parser.add_argument( - "--environment", - required=False, - help="Path to JSON file with the configuration of the benchmarking environment(s).", - ) + '--environment', required=False, + help='Path to JSON file with the configuration of the benchmarking environment(s).') parser.add_argument( - "--optimizer", - required=False, - help="Path to the optimizer configuration file. If omitted, run" - + " a single trial with default (or specified in --tunable_values).", - ) + '--optimizer', required=False, + help='Path to the optimizer configuration file. If omitted, run' + + ' a single trial with default (or specified in --tunable_values).') parser.add_argument( - "--trial_config_repeat_count", - "--trial-config-repeat-count", - required=False, - type=int, - help="Number of times to repeat each config. Default is 1 trial per config, though more may be advised.", - ) + '--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int, + help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.') parser.add_argument( - "--scheduler", - required=False, - help="Path to the scheduler configuration file. By default, use" - + " a single worker synchronous scheduler.", - ) + '--scheduler', required=False, + help='Path to the scheduler configuration file. By default, use' + + ' a single worker synchronous scheduler.') parser.add_argument( - "--storage", - required=False, - help="Path to the storage configuration file." 
- + " If omitted, use the ephemeral in-memory SQL storage.", - ) + '--storage', required=False, + help='Path to the storage configuration file.' + + ' If omitted, use the ephemeral in-memory SQL storage.') parser.add_argument( - "--random_init", - "--random-init", - required=False, - default=False, - dest="random_init", - action="store_true", - help="Initialize tunables with random values. (Before applying --tunable_values).", - ) + '--random_init', '--random-init', required=False, default=False, + dest='random_init', action='store_true', + help='Initialize tunables with random values. (Before applying --tunable_values).') parser.add_argument( - "--random_seed", - "--random-seed", - required=False, - type=int, - help="Seed to use with --random_init", - ) + '--random_seed', '--random-seed', required=False, type=int, + help='Seed to use with --random_init') parser.add_argument( - "--tunable_values", - "--tunable-values", - nargs="+", - action="extend", - required=False, - help="Path to one or more JSON files that contain values of the tunable" - + " parameters. This can be used for a single trial (when no --optimizer" - + " is specified) or as default values for the first run in optimization.", - ) + '--tunable_values', '--tunable-values', nargs="+", action='extend', required=False, + help='Path to one or more JSON files that contain values of the tunable' + + ' parameters. This can be used for a single trial (when no --optimizer' + + ' is specified) or as default values for the first run in optimization.') parser.add_argument( - "--globals", - nargs="+", - action="extend", - required=False, - help="Path to one or more JSON files that contain additional" - + " [private] parameters of the benchmarking environment.", - ) + '--globals', nargs="+", action='extend', required=False, + help='Path to one or more JSON files that contain additional' + + ' [private] parameters of the benchmarking environment.') parser.add_argument( - "--no_teardown", - "--no-teardown", - required=False, - default=None, - dest="teardown", - action="store_false", - help="Disable teardown of the environment after the benchmark.", - ) + '--no_teardown', '--no-teardown', required=False, default=None, + dest='teardown', action='store_false', + help='Disable teardown of the environment after the benchmark.') parser.add_argument( - "--experiment_id", - "--experiment-id", - required=False, - default=None, + '--experiment_id', '--experiment-id', required=False, default=None, help=""" Experiment ID to use for the benchmark. If omitted, the value from the --cli config or --globals is used. @@ -313,7 +246,7 @@ def _parse_args( changes are made to config files, scripts, versions, etc. This is left as a manual operation as detection of what is "incompatible" is not easily automatable across systems. - """, + """ ) # By default we use the command line arguments, but allow the caller to @@ -355,18 +288,16 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]: _LOG.debug("Parsed config: %s", config) return config - def _load_config( - self, - args_globals: Iterable[str], - config_path: Iterable[str], - args_rest: Iterable[str], - global_config: Dict[str, Any], - ) -> Dict[str, Any]: + def _load_config(self, + args_globals: Iterable[str], + config_path: Iterable[str], + args_rest: Iterable[str], + global_config: Dict[str, Any]) -> Dict[str, Any]: """ Get key/value pairs of the global configuration parameters from the specified config files (if any) and command line arguments. 
""" - for config_file in args_globals or []: + for config_file in (args_globals or []): conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS) assert isinstance(conf, dict) global_config.update(conf) @@ -375,9 +306,8 @@ def _load_config( global_config["config_path"] = config_path return global_config - def _init_tunable_values( - self, random_init: bool, seed: Optional[int], args_tunables: Optional[str] - ) -> TunableGroups: + def _init_tunable_values(self, random_init: bool, seed: Optional[int], + args_tunables: Optional[str]) -> TunableGroups: """ Initialize the tunables and load key/value pairs of the tunable values from given JSON files, if specified. @@ -387,10 +317,8 @@ def _init_tunable_values( if random_init: tunables = MockOptimizer( - tunables=tunables, - service=None, - config={"start_with_defaults": False, "seed": seed}, - ).suggest() + tunables=tunables, service=None, + config={"start_with_defaults": False, "seed": seed}).suggest() _LOG.debug("Init tunables: random = %s", tunables) if args_tunables is not None: @@ -411,20 +339,15 @@ def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer: if args_optimizer is None: # global_config may contain additional properties, so we need to # strip those out before instantiating the basic oneshot optimizer. - config = { - key: val - for key, val in self.global_config.items() - if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS - } - return OneShotOptimizer(self.tunables, config=config, service=self._parent_service) + config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS} + return OneShotOptimizer( + self.tunables, config=config, service=self._parent_service) class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER) assert isinstance(class_config, Dict) - optimizer = self._config_loader.build_optimizer( - tunables=self.tunables, - service=self._parent_service, - config=class_config, - global_config=self.global_config, - ) + optimizer = self._config_loader.build_optimizer(tunables=self.tunables, + service=self._parent_service, + config=class_config, + global_config=self.global_config) return optimizer def _load_storage(self, args_storage: Optional[str]) -> Storage: @@ -436,20 +359,17 @@ def _load_storage(self, args_storage: Optional[str]) -> Storage: if args_storage is None: # pylint: disable=import-outside-toplevel from mlos_bench.storage.sql.storage import SqlStorage - - return SqlStorage( - service=self._parent_service, - config={ - "drivername": "sqlite", - "database": ":memory:", - "lazy_schema_create": True, - }, - ) + return SqlStorage(service=self._parent_service, + config={ + "drivername": "sqlite", + "database": ":memory:", + "lazy_schema_create": True, + }) class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE) assert isinstance(class_config, Dict) - storage = self._config_loader.build_storage( - service=self._parent_service, config=class_config, global_config=self.global_config - ) + storage = self._config_loader.build_storage(service=self._parent_service, + config=class_config, + global_config=self.global_config) return storage def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: @@ -464,7 +384,6 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: if args_scheduler is None: # pylint: disable=import-outside-toplevel from mlos_bench.schedulers.sync_scheduler import SyncScheduler - return SyncScheduler( # All config values can be 
overridden from global config config={ diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py index a61b55d440..f10fa3c82e 100644 --- a/mlos_bench/mlos_bench/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/optimizers/__init__.py @@ -12,8 +12,8 @@ from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer __all__ = [ - "Optimizer", - "MockOptimizer", - "OneShotOptimizer", - "MlosCoreOptimizer", + 'Optimizer', + 'MockOptimizer', + 'OneShotOptimizer', + 'MlosCoreOptimizer', ] diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index f719c236e5..b9df1db1b7 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -26,7 +26,7 @@ _LOG = logging.getLogger(__name__) -class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes +class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ An abstract interface between the benchmarking framework and mlos_core optimizers. """ @@ -39,13 +39,11 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attribu "start_with_defaults", } - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): """ Create a new optimizer for the given configuration space defined by the tunables. @@ -69,20 +67,19 @@ def __init__( self._seed = int(config.get("seed", 42)) self._in_context = False - experiment_id = self._global_config.get("experiment_id") + experiment_id = self._global_config.get('experiment_id') self.experiment_id = str(experiment_id).strip() if experiment_id else None self._iter = 0 # If False, use the optimizer to suggest the initial configuration; # if True (default), use the already initialized values for the first iteration. self._start_with_defaults: bool = bool( - strtobool(str(self._config.pop("start_with_defaults", True))) - ) - self._max_iter = int(self._config.pop("max_suggestions", 100)) + strtobool(str(self._config.pop('start_with_defaults', True)))) + self._max_iter = int(self._config.pop('max_suggestions', 100)) - opt_targets: Dict[str, str] = self._config.pop("optimization_targets", {"score": "min"}) + opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'}) self._opt_targets: Dict[str, Literal[1, -1]] = {} - for opt_target, opt_dir in opt_targets.items(): + for (opt_target, opt_dir) in opt_targets.items(): if opt_dir == "min": self._opt_targets[opt_target] = 1 elif opt_dir == "max": @@ -110,7 +107,7 @@ def __repr__(self) -> str: ) return f"{self.name}({opt_targets},config={self._config})" - def __enter__(self) -> "Optimizer": + def __enter__(self) -> 'Optimizer': """ Enter the optimizer's context. """ @@ -119,12 +116,9 @@ def __enter__(self) -> "Optimizer": self._in_context = True return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the optimizer. 
""" @@ -209,7 +203,7 @@ def name(self) -> str: return self.__class__.__name__ @property - def targets(self) -> Dict[str, Literal["min", "max"]]: + def targets(self) -> Dict[str, Literal['min', 'max']]: """ A dictionary of {target: direction} of optimization targets. """ @@ -226,12 +220,10 @@ def supports_preload(self) -> bool: return True @abstractmethod - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: """ Pre-load the optimizer with the bulk data from previous experiments. @@ -249,12 +241,8 @@ def bulk_register( is_not_empty : bool True if there is data to register, false otherwise. """ - _LOG.info( - "Update the optimizer with: %d configs, %d scores, %d status values", - len(configs or []), - len(scores or []), - len(status or []), - ) + _LOG.info("Update the optimizer with: %d configs, %d scores, %d status values", + len(configs or []), len(scores or []), len(status or [])) if len(configs or []) != len(scores or []): raise ValueError("Numbers of configs and scores do not match.") if status is not None and len(configs or []) != len(status or []): @@ -283,12 +271,8 @@ def suggest(self) -> TunableGroups: return self._tunables.copy() @abstractmethod - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: """ Register the observation for the given configuration. @@ -309,16 +293,15 @@ def register( Benchmark scores extracted (and possibly transformed) from the dataframe that's being MINIMIZED. """ - _LOG.info( - "Iteration %d :: Register: %s = %s score: %s", self._iter, tunables, status, score - ) + _LOG.info("Iteration %d :: Register: %s = %s score: %s", + self._iter, tunables, status, score) if status.is_succeeded() == (score is None): # XOR raise ValueError("Status and score must be consistent.") return self._get_scores(status, score) - def _get_scores( - self, status: Status, scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] - ) -> Optional[Dict[str, float]]: + def _get_scores(self, status: Status, + scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] + ) -> Optional[Dict[str, float]]: """ Extract a scalar benchmark score from the dataframe. Change the sign if we are maximizing. @@ -347,7 +330,7 @@ def _get_scores( assert scores is not None target_metrics: Dict[str, float] = {} - for opt_target, opt_dir in self._opt_targets.items(): + for (opt_target, opt_dir) in self._opt_targets.items(): val = scores[opt_target] assert val is not None target_metrics[opt_target] = float(val) * opt_dir @@ -362,9 +345,7 @@ def not_converged(self) -> bool: return self._iter < self._max_iter @abstractmethod - def get_best_observation( - self, - ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: """ Get the best observation so far. 
diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index a98edb463b..62341c613d 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -48,8 +48,7 @@ def _normalize_weights(weights: List[float]) -> List[float]: def _tunable_to_configspace( - tunable: Tunable, group_name: Optional[str] = None, cost: int = 0 -) -> ConfigurationSpace: + tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace: """ Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects, wrapped in a ConfigurationSpace for composability. @@ -72,28 +71,27 @@ def _tunable_to_configspace( meta = {"group": group_name, "cost": cost} # {"scaling": ""} if tunable.type == "categorical": - return ConfigurationSpace( - { - tunable.name: CategoricalHyperparameter( - name=tunable.name, - choices=tunable.categories, - weights=_normalize_weights(tunable.weights) if tunable.weights else None, - default_value=tunable.default, - meta=meta, - ) - } - ) + return ConfigurationSpace({ + tunable.name: CategoricalHyperparameter( + name=tunable.name, + choices=tunable.categories, + weights=_normalize_weights(tunable.weights) if tunable.weights else None, + default_value=tunable.default, + meta=meta) + }) distribution: Union[Uniform, Normal, Beta, None] = None if tunable.distribution == "uniform": distribution = Uniform() elif tunable.distribution == "normal": distribution = Normal( - mu=tunable.distribution_params["mu"], sigma=tunable.distribution_params["sigma"] + mu=tunable.distribution_params["mu"], + sigma=tunable.distribution_params["sigma"] ) elif tunable.distribution == "beta": distribution = Beta( - alpha=tunable.distribution_params["alpha"], beta=tunable.distribution_params["beta"] + alpha=tunable.distribution_params["alpha"], + beta=tunable.distribution_params["beta"] ) elif tunable.distribution is not None: raise TypeError(f"Invalid Distribution Type: {tunable.distribution}") @@ -105,26 +103,22 @@ def _tunable_to_configspace( log=bool(tunable.is_log), q=nullable(int, tunable.quantization), distribution=distribution, - default=( - int(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None - ), - meta=meta, + default=(int(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None), + meta=meta ) elif tunable.type == "float": range_hp = Float( name=tunable.name, bounds=tunable.range, log=bool(tunable.is_log), - q=tunable.quantization, # type: ignore[arg-type] + q=tunable.quantization, # type: ignore[arg-type] distribution=distribution, # type: ignore[arg-type] - default=( - float(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None - ), - meta=meta, + default=(float(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None), + meta=meta ) else: raise TypeError(f"Invalid Parameter Type: {tunable.type}") @@ -142,37 +136,31 @@ def _tunable_to_configspace( # Create three hyperparameters: one for regular values, # one for special values, and one to choose between the two. 
(special_name, type_name) = special_param_names(tunable.name) - conf_space = ConfigurationSpace( - { - tunable.name: range_hp, - special_name: CategoricalHyperparameter( - name=special_name, - choices=tunable.special, - weights=special_weights, - default_value=tunable.default if tunable.default in tunable.special else None, - meta=meta, - ), - type_name: CategoricalHyperparameter( - name=type_name, - choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], - weights=switch_weights, - default_value=TunableValueKind.SPECIAL, - ), - } - ) - conf_space.add_condition( - EqualsCondition(conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL) - ) - conf_space.add_condition( - EqualsCondition(conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE) - ) + conf_space = ConfigurationSpace({ + tunable.name: range_hp, + special_name: CategoricalHyperparameter( + name=special_name, + choices=tunable.special, + weights=special_weights, + default_value=tunable.default if tunable.default in tunable.special else None, + meta=meta + ), + type_name: CategoricalHyperparameter( + name=type_name, + choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], + weights=switch_weights, + default_value=TunableValueKind.SPECIAL, + ), + }) + conf_space.add_condition(EqualsCondition( + conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)) + conf_space.add_condition(EqualsCondition( + conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)) return conf_space -def tunable_groups_to_configspace( - tunables: TunableGroups, seed: Optional[int] = None -) -> ConfigurationSpace: +def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace: """ Convert TunableGroups to hyperparameters in ConfigurationSpace. @@ -190,14 +178,11 @@ def tunable_groups_to_configspace( A new ConfigurationSpace instance that corresponds to the input TunableGroups. """ space = ConfigurationSpace(seed=seed) - for tunable, group in tunables: + for (tunable, group) in tunables: space.add_configuration_space( - prefix="", - delimiter="", + prefix="", delimiter="", configuration_space=_tunable_to_configspace( - tunable, group.name, group.get_current_cost() - ), - ) + tunable, group.name, group.get_current_cost())) return space @@ -216,7 +201,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration: A ConfigSpace Configuration. """ values: Dict[str, TunableValue] = {} - for tunable, _group in tunables: + for (tunable, _group) in tunables: if tunable.special: (special_name, type_name) = special_param_names(tunable.name) if tunable.value in tunable.special: @@ -237,7 +222,10 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]: In particular, remove and keys suffixes added by `special_param_names`. 
""" data = data.copy() - specials = [special_param_name_strip(k) for k in data.keys() if special_param_name_is_temp(k)] + specials = [ + special_param_name_strip(k) + for k in data.keys() if special_param_name_is_temp(k) + ] for k in specials: (special_name, type_name) = special_param_names(k) if data[type_name] == TunableValueKind.SPECIAL: diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 6e5700a37d..4f207f5fc9 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -28,13 +28,11 @@ class GridSearchOptimizer(TrackBestOptimizer): Grid search optimizer. """ - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) # Track the grid as a set of tuples of tunable values and reconstruct the @@ -53,19 +51,11 @@ def __init__( def _sanity_check(self) -> None: size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables]) if size == np.inf: - raise ValueError( - f"Unquantized tunables are not supported for grid search: {self._tunables}" - ) + raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}") if size > 10000: - _LOG.warning( - "Large number %d of config points requested for grid search: %s", - size, - self._tunables, - ) + _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables) if size > self._max_iter: - _LOG.warning( - "Grid search size %d, is greater than max iterations %d", size, self._max_iter - ) + _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter) def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]: """ @@ -78,14 +68,12 @@ def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], Non # names instead of the order given by TunableGroups. configs = [ configspace_data_to_tunable_values(dict(config)) - for config in generate_grid( - self.config_space, - { - tunable.name: int(tunable.cardinality) - for (tunable, _group) in self._tunables - if tunable.quantization or tunable.type == "int" - }, - ) + for config in + generate_grid(self.config_space, { + tunable.name: int(tunable.cardinality) + for (tunable, _group) in self._tunables + if tunable.quantization or tunable.type == "int" + }) ] names = set(tuple(configs.keys()) for configs in configs) assert len(names) == 1 @@ -115,17 +103,15 @@ def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]: # See NOTEs above. 
return (dict(zip(self._config_keys, config)) for config in self._suggested_configs) - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for params, score, trial_status in zip(configs, scores, status): + for (params, score, trial_status) in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -166,32 +152,20 @@ def suggest(self) -> TunableGroups: _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) try: - config = dict( - ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()) - ) + config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())) self._suggested_configs.remove(tuple(config.values())) except KeyError: - _LOG.warning( - "Attempted to remove missing config (previously registered?) from suggested set: %s", - tunables, - ) + _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables) return registered_score def not_converged(self) -> bool: if self._iter > self._max_iter: if bool(self._pending_configs): - _LOG.warning( - "Exceeded max iterations, but still have %d pending configs: %s", - len(self._pending_configs), - list(self._pending_configs.keys()), - ) + _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s", + len(self._pending_configs), list(self._pending_configs.keys())) return False return bool(self._pending_configs) diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index a13ebe8d10..d7d50f1ca5 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -40,41 +40,35 @@ class MlosCoreOptimizer(Optimizer): A wrapper class for the mlos_core optimizers. """ - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) - opt_type = getattr( - OptimizerType, self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name) - ) + opt_type = getattr(OptimizerType, self._config.pop( + 'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name)) if opt_type == OptimizerType.SMAC: - output_directory = self._config.get("output_directory") + output_directory = self._config.get('output_directory') if output_directory is not None: # If output_directory is specified, turn it into an absolute path. 
- self._config["output_directory"] = os.path.abspath(output_directory) + self._config['output_directory'] = os.path.abspath(output_directory) else: - _LOG.warning( - "SMAC optimizer output_directory was null. SMAC will use a temporary directory." - ) + _LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.") # Make sure max_trials >= max_iterations. - if "max_trials" not in self._config: - self._config["max_trials"] = self._max_iter - assert ( - int(self._config["max_trials"]) >= self._max_iter - ), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" + if 'max_trials' not in self._config: + self._config['max_trials'] = self._max_iter + assert int(self._config['max_trials']) >= self._max_iter, \ + f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" - if "run_name" not in self._config and self.experiment_id: - self._config["run_name"] = self.experiment_id + if 'run_name' not in self._config and self.experiment_id: + self._config['run_name'] = self.experiment_id - space_adapter_type = self._config.pop("space_adapter_type", None) - space_adapter_config = self._config.pop("space_adapter_config", {}) + space_adapter_type = self._config.pop('space_adapter_type', None) + space_adapter_config = self._config.pop('space_adapter_config', {}) if space_adapter_type is not None: space_adapter_type = getattr(SpaceAdapterType, space_adapter_type) @@ -88,12 +82,9 @@ def __init__( space_adapter_kwargs=space_adapter_config, ) - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: self._opt.cleanup() return super().__exit__(ex_type, ex_val, ex_tb) @@ -101,12 +92,10 @@ def __exit__( def name(self) -> str: return f"{self.__class__.__name__}:{self._opt.__class__.__name__}" - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False @@ -114,8 +103,7 @@ def bulk_register( df_configs = self._to_df(configs) # Impute missing values, if necessary df_scores = self._adjust_signs_df( - pd.DataFrame([{} if score is None else score for score in scores]) - ) + pd.DataFrame([{} if score is None else score for score in scores])) opt_targets = list(self._opt_targets) if status is not None: @@ -142,7 +130,7 @@ def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame: """ In-place adjust the signs of the scores for MINIMIZATION problem. 
""" - for opt_target, opt_dir in self._opt_targets.items(): + for (opt_target, opt_dir) in self._opt_targets.items(): df_scores[opt_target] *= opt_dir return df_scores @@ -164,7 +152,7 @@ def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: df_configs = pd.DataFrame(configs) tunables_names = list(self._tunables.get_param_values().keys()) missing_cols = set(tunables_names).difference(df_configs.columns) - for tunable, _group in self._tunables: + for (tunable, _group) in self._tunables: if tunable.name in missing_cols: df_configs[tunable.name] = tunable.default else: @@ -196,31 +184,22 @@ def suggest(self) -> TunableGroups: df_config, _metadata = self._opt.suggest(defaults=self._start_with_defaults) self._start_with_defaults = False _LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config) - return tunables.assign(configspace_data_to_tunable_values(df_config.loc[0].to_dict())) - - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: - registered_score = super().register( - tunables, status, score - ) # Sign-adjusted for MINIMIZATION + return tunables.assign( + configspace_data_to_tunable_values(df_config.loc[0].to_dict())) + + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION if status.is_completed(): assert registered_score is not None df_config = self._to_df([tunables.get_param_values()]) _LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config) # TODO: Specify (in the config) which metrics to pass to the optimizer. # Issue: https://github.com/microsoft/MLOS/issues/745 - self._opt.register( - configs=df_config, scores=pd.DataFrame([registered_score], dtype=float) - ) + self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float)) return registered_score - def get_best_observation( - self, - ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: (df_config, df_score, _df_context) = self._opt.get_best_observations() if len(df_config) == 0: return (None, None) diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index 8dd13eb182..ada4411b58 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -24,13 +24,11 @@ class MockOptimizer(TrackBestOptimizer): Mock optimizer to test the Environment API. 
""" - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) rnd = random.Random(self.seed) self._random: Dict[str, Callable[[Tunable], TunableValue]] = { @@ -39,17 +37,15 @@ def __init__( "int": lambda tunable: rnd.randint(*tunable.range), } - def bulk_register( - self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None, - ) -> bool: + def bulk_register(self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for params, score, trial_status in zip(configs, scores, status): + for (params, score, trial_status) in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -66,7 +62,7 @@ def suggest(self) -> TunableGroups: _LOG.info("Use default tunable values") self._start_with_defaults = False else: - for tunable, _group in tunables: + for (tunable, _group) in tunables: tunable.value = self._random[tunable.type](tunable) _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index b7a14f8af2..9ad1070c46 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -24,13 +24,11 @@ class OneShotOptimizer(MockOptimizer): # TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344 - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) _LOG.info("Run a single iteration for: %s", self._tunables) self._max_iter = 1 # Always run for just one iteration. diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py index 0fd54b2dfa..32a23142e3 100644 --- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py @@ -24,23 +24,17 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta): Base Optimizer class that keeps track of the best score and configuration. 
""" - def __init__( - self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(tunables, config, global_config, service) self._best_config: Optional[TunableGroups] = None self._best_score: Optional[Dict[str, float]] = None - def register( - self, - tunables: TunableGroups, - status: Status, - score: Optional[Dict[str, TunableValue]] = None, - ) -> Optional[Dict[str, float]]: + def register(self, tunables: TunableGroups, status: Status, + score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) if status.is_succeeded() and self._is_better(registered_score): self._best_score = registered_score @@ -54,7 +48,7 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: if self._best_score is None: return True assert registered_score is not None - for opt_target, best_score in self._best_score.items(): + for (opt_target, best_score) in self._best_score.items(): score = registered_score[opt_target] if score < best_score: return True @@ -62,9 +56,7 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: return False return False - def get_best_observation( - self, - ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: if self._best_score is None: return (None, None) score = self._get_scores(Status.SUCCEEDED, self._best_score) diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index 7f26851c6b..a7912688a1 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py @@ -22,19 +22,16 @@ from typing_extensions import TypeAlias if sys.version_info >= (3, 9): - EnvironType: TypeAlias = os._Environ[ - str - ] # pylint: disable=protected-access,disable=unsubscriptable-object + EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object else: - EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access + EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access # Handle case sensitivity differences between platforms. 
# https://stackoverflow.com/a/19023293 -if sys.platform == "win32": +if sys.platform == 'win32': import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8) - environ: EnvironType = nt.environ else: environ: EnvironType = os.environ -__all__ = ["environ"] +__all__ = ['environ'] diff --git a/mlos_bench/mlos_bench/run.py b/mlos_bench/mlos_bench/run.py index 57c48a87b9..85c8c2b0c5 100755 --- a/mlos_bench/mlos_bench/run.py +++ b/mlos_bench/mlos_bench/run.py @@ -20,9 +20,8 @@ _LOG = logging.getLogger(__name__) -def _main( - argv: Optional[List[str]] = None, -) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: +def _main(argv: Optional[List[str]] = None + ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv) diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py index c53d11231d..c54e3c0efc 100644 --- a/mlos_bench/mlos_bench/schedulers/__init__.py +++ b/mlos_bench/mlos_bench/schedulers/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.schedulers.sync_scheduler import SyncScheduler __all__ = [ - "Scheduler", - "SyncScheduler", + 'Scheduler', + 'SyncScheduler', ] diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index 210e2784a5..0b6733e423 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -31,16 +31,13 @@ class Scheduler(metaclass=ABCMeta): Base class for the optimization loop scheduling policies. """ - def __init__( - self, - *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: Storage, - root_env_config: str, - ): + def __init__(self, *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: Storage, + root_env_config: str): """ Create a new instance of the scheduler. The constructor of this and the derived classes is called by the persistence service @@ -63,9 +60,8 @@ def __init__( Path to the root environment configuration. """ self.global_config = global_config - config = merge_parameters( - dest=config.copy(), source=global_config, required_keys=["experiment_id", "trial_id"] - ) + config = merge_parameters(dest=config.copy(), source=global_config, + required_keys=["experiment_id", "trial_id"]) self._experiment_id = config["experiment_id"].strip() self._trial_id = int(config["trial_id"]) @@ -75,9 +71,7 @@ def __init__( self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1)) if self._trial_config_repeat_count <= 0: - raise ValueError( - f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}" - ) + raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}") self._do_teardown = bool(config.get("teardown", True)) @@ -101,7 +95,7 @@ def __repr__(self) -> str: """ return self.__class__.__name__ - def __enter__(self) -> "Scheduler": + def __enter__(self) -> 'Scheduler': """ Enter the scheduler's context. 
""" @@ -123,12 +117,10 @@ def __enter__(self) -> "Scheduler": ).__enter__() return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the context of the scheduler. """ @@ -150,12 +142,8 @@ def start(self) -> None: Start the optimization loop. """ assert self.experiment is not None - _LOG.info( - "START: Experiment: %s Env: %s Optimizer: %s", - self.experiment, - self.environment, - self.optimizer, - ) + _LOG.info("START: Experiment: %s Env: %s Optimizer: %s", + self.experiment, self.environment, self.optimizer) if _LOG.isEnabledFor(logging.INFO): _LOG.info("Root Environment:\n%s", self.environment.pprint()) @@ -216,33 +204,27 @@ def schedule_trial(self, tunables: TunableGroups) -> None: Add a configuration to the queue of trials. """ for repeat_i in range(1, self._trial_config_repeat_count + 1): - self._add_trial_to_queue( - tunables, - config={ - # Add some additional metadata to track for the trial such as the - # optimizer config used. - # Note: these values are unfortunately mutable at the moment. - # Consider them as hints of what the config was the trial *started*. - # It is possible that the experiment configs were changed - # between resuming the experiment (since that is not currently - # prevented). - "optimizer": self.optimizer.name, - "repeat_i": repeat_i, - "is_defaults": tunables.is_defaults, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(self.optimizer.targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - }, - }, - ) + self._add_trial_to_queue(tunables, config={ + # Add some additional metadata to track for the trial such as the + # optimizer config used. + # Note: these values are unfortunately mutable at the moment. + # Consider them as hints of what the config was the trial *started*. + # It is possible that the experiment configs were changed + # between resuming the experiment (since that is not currently + # prevented). + "optimizer": self.optimizer.name, + "repeat_i": repeat_i, + "is_defaults": tunables.is_defaults, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(self.optimizer.targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + } + }) - def _add_trial_to_queue( - self, - tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None, - ) -> None: + def _add_trial_to_queue(self, tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None) -> None: """ Add a configuration to the queue of trials. A wrapper for the `Experiment.new_trial` method. diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index 3e196d4d4f..a73a493533 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -53,9 +53,7 @@ def run_trial(self, trial: Storage.Trial) -> None: trial.update(Status.FAILED, datetime.now(UTC)) return - (status, timestamp, results) = ( - self.environment.run() - ) # Block and wait for the final result. + (status, timestamp, results) = self.environment.run() # Block and wait for the final result. 
_LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results) # In async mode (TODO), poll the environment for status and telemetry diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py index dacbb88126..bcc7d02d6f 100644 --- a/mlos_bench/mlos_bench/services/__init__.py +++ b/mlos_bench/mlos_bench/services/__init__.py @@ -11,7 +11,7 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - "Service", - "FileShareService", - "LocalExecService", + 'Service', + 'FileShareService', + 'LocalExecService', ] diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index 63c222ee45..f00a7a1a00 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -21,13 +21,10 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta): An abstract base of all file shares. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new file share with a given config. @@ -45,16 +42,12 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.upload, self.download]), + config, global_config, parent, + self.merge_methods(methods, [self.upload, self.download]) ) @abstractmethod - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: """ Downloads contents from a remote share path to a local path. @@ -72,18 +65,11 @@ def download( if True (the default), download the entire directory tree. """ params = params or {} - _LOG.info( - "Download from File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", - remote_path, - local_path, - params, - ) + _LOG.info("Download from File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", remote_path, local_path, params) @abstractmethod - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: """ Uploads contents from a local path to remote share path. @@ -100,10 +86,5 @@ def upload( if True (the default), upload the entire directory tree. 
""" params = params or {} - _LOG.info( - "Upload to File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", - local_path, - remote_path, - params, - ) + _LOG.info("Upload to File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", local_path, remote_path, params) diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index 65725b6288..e7c9365bf7 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -26,13 +26,11 @@ class Service: """ @classmethod - def new( - cls, - class_name: str, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - ) -> "Service": + def new(cls, + class_name: str, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None) -> "Service": """ Factory method for a new service with a given config. @@ -59,13 +57,11 @@ def new( assert issubclass(cls, Service) return instantiate_from_config(cls, class_name, config, global_config, parent) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new service with a given config. @@ -105,10 +101,8 @@ def __init__( _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None) @staticmethod - def merge_methods( - ext_methods: Union[Dict[str, Callable], List[Callable], None], - local_methods: Union[Dict[str, Callable], List[Callable]], - ) -> Dict[str, Callable]: + def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None], + local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]: """ Merge methods from the external caller with the local ones. This function is usually called by the derived class constructor @@ -144,12 +138,9 @@ def __enter__(self) -> "Service": self._in_context = True return self - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exit the Service mix-in context. @@ -186,12 +177,9 @@ def _enter_context(self) -> "Service": self._in_context = True return self - def _exit_context( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def _exit_context(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: """ Exits the context for this particular Service instance. @@ -277,11 +265,10 @@ def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None # Unfortunately, by creating a set, we may destroy the ability to # preserve the context enter/exit order, but hopefully it doesn't # matter. 
- svc_method.__self__ - for _, svc_method in self._service_methods.items() + svc_method.__self__ for _, svc_method in self._service_methods.items() # Note: some methods are actually stand alone functions, so we need # to filter them out. - if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service) + if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service) } def export(self) -> Dict[str, Callable]: diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index 55d8e67527..cac3216d61 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -61,13 +61,11 @@ class ConfigPersistenceService(Service, SupportsConfigLoading): BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/") - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of config persistence service. @@ -84,22 +82,17 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.resolve_path, - self.load_config, - self.prepare_class_load, - self.build_service, - self.build_environment, - self.load_services, - self.load_environment, - self.load_environment_list, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + self.resolve_path, + self.load_config, + self.prepare_class_load, + self.build_service, + self.build_environment, + self.load_services, + self.load_environment, + self.load_environment_list, + ]) ) self._config_loader_service = self @@ -127,7 +120,8 @@ def config_paths(self) -> List[str]: """ return list(self._config_path) # make a copy to avoid modifications - def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path(self, file_path: str, + extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -157,11 +151,10 @@ def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = No _LOG.debug("Path not resolved: %s", file_path) return file_path - def load_config( - self, - json_file_name: str, - schema_type: Optional[ConfigSchema], - ) -> Dict[str, Any]: + def load_config(self, + json_file_name: str, + schema_type: Optional[ConfigSchema], + ) -> Dict[str, Any]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. 
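The resolve_path() helper shown above performs a simple search-path lookup: absolute paths are returned unchanged, while relative paths are tried against each configured root and the first existing match wins, falling back to the original string. A minimal standalone sketch of that behavior (the real service also accepts extra_paths and logs unresolved paths; names below are illustrative):

    import os
    from typing import Iterable, Optional

    def resolve_path(file_path: str, search_paths: Optional[Iterable[str]] = None) -> str:
        # Absolute paths are returned as-is; relative ones are tried against each root.
        if os.path.isabs(file_path):
            return file_path
        for root in search_paths or []:
            candidate = os.path.join(root, file_path)
            if os.path.exists(candidate):
                return candidate
        return file_path  # fall back to the original, unresolved path

    print(resolve_path("/etc/hosts"))                      # absolute -> unchanged
    print(resolve_path("app.jsonc", ["./configs", "."]))   # first existing match, else as-is
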
@@ -181,22 +174,16 @@ def load_config( """ json_file_name = self.resolve_path(json_file_name) _LOG.info("Load config: %s", json_file_name) - with open(json_file_name, mode="r", encoding="utf-8") as fh_json: + with open(json_file_name, mode='r', encoding='utf-8') as fh_json: config = json5.load(fh_json) if schema_type is not None: try: schema_type.validate(config) except (ValidationError, SchemaError) as ex: - _LOG.error( - "Failed to validate config %s against schema type %s at %s", - json_file_name, - schema_type.name, - schema_type.value, - ) - raise ValueError( - f"Failed to validate config {json_file_name} against " - + f"schema type {schema_type.name} at {schema_type.value}" - ) from ex + _LOG.error("Failed to validate config %s against schema type %s at %s", + json_file_name, schema_type.name, schema_type.value) + raise ValueError(f"Failed to validate config {json_file_name} against " + + f"schema type {schema_type.name} at {schema_type.value}") from ex if isinstance(config, dict) and config.get("$schema"): # Remove $schema attributes from the config after we've validated # them to avoid passing them on to other objects @@ -207,14 +194,11 @@ def load_config( del config["$schema"] else: _LOG.warning("Config %s is not validated against a schema.", json_file_name) - return config # type: ignore[no-any-return] + return config # type: ignore[no-any-return] - def prepare_class_load( - self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - ) -> Tuple[str, Dict[str, Any]]: + def prepare_class_load(self, config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. Mix-in the global parameters and resolve the local file system paths, @@ -257,22 +241,16 @@ def prepare_class_load( raise ValueError(f"Parameter {key} must be a string or a list") if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Instantiating: %s with config:\n%s", - class_name, - json.dumps(class_config, indent=2), - ) + _LOG.debug("Instantiating: %s with config:\n%s", + class_name, json.dumps(class_config, indent=2)) return (class_name, class_config) - def build_optimizer( - self, - *, - tunables: TunableGroups, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - ) -> Optimizer: + def build_optimizer(self, *, + tunables: TunableGroups, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None) -> Optimizer: """ Instantiation of mlos_bench Optimizer that depend on Service and TunableGroups. 
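load_config() above parses the file with json5 (so configs may carry comments and trailing commas) and, after the optional schema validation, removes the "$schema" hint before handing the dictionary on. A small self-contained sketch of that flow, assuming the json5 package with its json-like loads(); the schema URL and the config keys below are placeholders borrowed from the optimizer options seen earlier:

    import json5  # the same JSON5 parser the loader above imports

    raw = """{
        // comments and trailing commas are allowed in JSON5
        "$schema": "https://example.com/optimizer.schema.json",
        "max_suggestions": 25,
        "optimization_targets": {"score": "min"},
    }"""

    config = json5.loads(raw)
    config.pop("$schema", None)  # drop the schema hint after validation, mirroring the loader above
    print(config)  # -> {'max_suggestions': 25, 'optimization_targets': {'score': 'min'}}
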
@@ -301,24 +279,18 @@ def build_optimizer( if tunables_path is not None: tunables = self._load_tunables(tunables_path, tunables) (class_name, class_config) = self.prepare_class_load(config, global_config) - inst = instantiate_from_config( - Optimizer, - class_name, # type: ignore[type-abstract] - tunables=tunables, - config=class_config, - global_config=global_config, - service=service, - ) + inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract] + tunables=tunables, + config=class_config, + global_config=global_config, + service=service) _LOG.info("Created: Optimizer %s", inst) return inst - def build_storage( - self, - *, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - ) -> "Storage": + def build_storage(self, *, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None) -> "Storage": """ Instantiation of mlos_bench Storage objects. @@ -340,27 +312,20 @@ def build_storage( from mlos_bench.storage.base_storage import ( Storage, # pylint: disable=import-outside-toplevel ) - - inst = instantiate_from_config( - Storage, - class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - service=service, - ) + inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + service=service) _LOG.info("Created: Storage %s", inst) return inst - def build_scheduler( - self, - *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: "Storage", - root_env_config: str, - ) -> "Scheduler": + def build_scheduler(self, *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: "Storage", + root_env_config: str) -> "Scheduler": """ Instantiation of mlos_bench Scheduler. @@ -388,28 +353,22 @@ def build_scheduler( from mlos_bench.schedulers.base_scheduler import ( Scheduler, # pylint: disable=import-outside-toplevel ) - - inst = instantiate_from_config( - Scheduler, - class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - environment=environment, - optimizer=optimizer, - storage=storage, - root_env_config=root_env_config, - ) + inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract] + config=class_config, + global_config=global_config, + environment=environment, + optimizer=optimizer, + storage=storage, + root_env_config=root_env_config) _LOG.info("Created: Scheduler %s", inst) return inst - def build_environment( - self, # pylint: disable=too-many-arguments - config: Dict[str, Any], - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None, - ) -> Environment: + def build_environment(self, # pylint: disable=too-many-arguments + config: Dict[str, Any], + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None) -> Environment: """ Factory method for a new environment with a given config. 
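Both build_optimizer() and build_storage() above funnel into instantiate_from_config(), which constructs an object from a class name plus config-derived keyword arguments. The following is only a generic standard-library sketch of that dynamic-construction pattern, not the mlos_bench helper itself:

    import importlib
    from typing import Any

    def instantiate_by_name(class_name: str, **kwargs: Any) -> Any:
        # Resolve "package.module.ClassName" and construct it with keyword arguments.
        module_name, _, cls_name = class_name.rpartition(".")
        cls = getattr(importlib.import_module(module_name), cls_name)
        return cls(**kwargs)

    # e.g. build a stdlib object purely from its dotted name:
    print(instantiate_by_name("datetime.timedelta", days=1, hours=2))  # -> 1 day, 2:00:00
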
@@ -449,24 +408,16 @@ def build_environment( tunables = self._load_tunables(env_tunables_path, tunables) _LOG.debug("Creating env: %s :: %s", env_name, env_class) - env = Environment.new( - env_name=env_name, - class_name=env_class, - config=env_config, - global_config=global_config, - tunables=tunables, - service=service, - ) + env = Environment.new(env_name=env_name, class_name=env_class, + config=env_config, global_config=global_config, + tunables=tunables, service=service) _LOG.info("Created env: %s :: %s", env_name, env) return env - def _build_standalone_service( - self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def _build_standalone_service(self, config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -491,12 +442,9 @@ def _build_standalone_service( _LOG.info("Created service: %s", service) return service - def _build_composite_service( - self, - config_list: Iterable[Dict[str, Any]], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -522,21 +470,18 @@ def _build_composite_service( service.register(parent.export()) for config in config_list: - service.register( - self._build_standalone_service(config, global_config, service).export() - ) + service.register(self._build_standalone_service( + config, global_config, service).export()) if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Created mix-in service: %s", service) return service - def build_service( - self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def build_service(self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Factory method for a new service with a given config. @@ -558,7 +503,8 @@ def build_service( services from the list plus the parent mix-in. """ if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2)) + _LOG.debug("Build service from config:\n%s", + json.dumps(config, indent=2)) assert isinstance(config, dict) config_list: List[Dict[str, Any]] @@ -573,14 +519,12 @@ def build_service( return self._build_composite_service(config_list, global_config, parent) - def load_environment( - self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None, - ) -> Environment: + def load_environment(self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None) -> Environment: """ Load and build new environment from the config file. 
@@ -607,14 +551,12 @@ def load_environment( assert isinstance(config, dict) return self.build_environment(config, tunables, global_config, parent_args, service) - def load_environment_list( - self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None, - ) -> List[Environment]: + def load_environment_list(self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None) -> List[Environment]: """ Load and build a list of environments from the config file. @@ -639,14 +581,13 @@ def load_environment_list( A list of new benchmarking environments. """ config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT) - return [self.build_environment(config, tunables, global_config, parent_args, service)] + return [ + self.build_environment(config, tunables, global_config, parent_args, service) + ] - def load_services( - self, - json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - ) -> Service: + def load_services(self, json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None) -> Service: """ Read the configuration files and bundle all service methods from those configs into a single Service object. @@ -665,16 +606,16 @@ def load_services( service : Service A collection of service methods. """ - _LOG.info("Load services: %s parent: %s", json_file_names, parent.__class__.__name__) + _LOG.info("Load services: %s parent: %s", + json_file_names, parent.__class__.__name__) service = Service({}, global_config, parent) for fname in json_file_names: config = self.load_config(fname, ConfigSchema.SERVICE) service.register(self.build_service(config, global_config, service).export()) return service - def _load_tunables( - self, json_file_names: Iterable[str], parent: TunableGroups - ) -> TunableGroups: + def _load_tunables(self, json_file_names: Iterable[str], + parent: TunableGroups) -> TunableGroups: """ Load a collection of tunable parameters from JSON files into the parent TunableGroup. diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py index b9d0c267c1..abb87c8b52 100644 --- a/mlos_bench/mlos_bench/services/local/__init__.py +++ b/mlos_bench/mlos_bench/services/local/__init__.py @@ -9,5 +9,5 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - "LocalExecService", + 'LocalExecService', ] diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 0486ab7c80..47534be7b1 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -79,13 +79,11 @@ class LocalExecService(TempDirContextService, SupportsLocalExec): due to reduced dependency management complications vs the target environment. 
""" - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of a service to run scripts locally. @@ -102,16 +100,14 @@ def __init__( New methods to register with the service. """ super().__init__( - config, global_config, parent, self.merge_methods(methods, [self.local_exec]) + config, global_config, parent, + self.merge_methods(methods, [self.local_exec]) ) self.abort_on_error = self.config.get("abort_on_error", True) - def local_exec( - self, - script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None, - ) -> Tuple[int, str, str]: + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. @@ -179,9 +175,9 @@ def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: subcmd_tokens.insert(0, sys.executable) return subcmd_tokens - def _local_exec_script( - self, script_line: str, env_params: Optional[Mapping[str, "TunableValue"]], cwd: str - ) -> Tuple[int, str, str]: + def _local_exec_script(self, script_line: str, + env_params: Optional[Mapping[str, "TunableValue"]], + cwd: str) -> Tuple[int, str, str]: """ Execute the script from `script_path` in a local process. @@ -210,7 +206,7 @@ def _local_exec_script( if env_params: env = {key: str(val) for (key, val) in env_params.items()} - if sys.platform == "win32": + if sys.platform == 'win32': # A hack to run Python on Windows with env variables set: env_copy = environ.copy() env_copy["PYTHONPATH"] = "" @@ -218,7 +214,7 @@ def _local_exec_script( env = env_copy try: - if sys.platform != "win32": + if sys.platform != 'win32': cmd = [" ".join(cmd)] _LOG.info("Run: %s", cmd) @@ -226,15 +222,8 @@ def _local_exec_script( _LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env)) _LOG.debug("Current working dir: %s", cwd) - proc = subprocess.run( - cmd, - env=env or None, - cwd=cwd, - shell=True, - text=True, - check=False, - capture_output=True, - ) + proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True, + text=True, check=False, capture_output=True) _LOG.debug("Run: return code = %d", proc.returncode) return (proc.returncode, proc.stdout, proc.stderr) diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py index 8512b5d282..a0cf3e0e57 100644 --- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py +++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py @@ -28,13 +28,11 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta): This class is not supposed to be used as a standalone service. 
""" - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of a service that provides temporary directory context for local exec service. @@ -52,7 +50,8 @@ def __init__( New methods to register with the service. """ super().__init__( - config, global_config, parent, self.merge_methods(methods, [self.temp_dir_context]) + config, global_config, parent, + self.merge_methods(methods, [self.temp_dir_context]) ) self._temp_dir = self.config.get("temp_dir") if self._temp_dir: @@ -62,9 +61,7 @@ def __init__( self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir) _LOG.info("%s: temp dir: %s", self, self._temp_dir) - def temp_dir_context( - self, path: Optional[str] = None - ) -> Union[TemporaryDirectory, nullcontext]: + def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py index 12fe62eeb7..61a6c74942 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py @@ -13,9 +13,9 @@ from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService __all__ = [ - "AzureAuthService", - "AzureFileShareService", - "AzureNetworkService", - "AzureSaaSConfigService", - "AzureVMService", + 'AzureAuthService', + 'AzureFileShareService', + 'AzureNetworkService', + 'AzureSaaSConfigService', + 'AzureVMService', ] diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 9074353221..4121446caf 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -27,15 +27,13 @@ class AzureAuthService(Service, SupportsAuth): Helper methods to get access to Azure services. """ - _REQ_INTERVAL = 300 # = 5 min - - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + _REQ_INTERVAL = 300 # = 5 min + + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure authentication services proxy. @@ -52,16 +50,11 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.get_access_token, - self.get_auth_headers, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + self.get_access_token, + self.get_auth_headers, + ]) ) # This parameter can come from command line as strings, so conversion is needed. 
@@ -77,13 +70,12 @@ def __init__( # Verify info required for SP auth early if "spClientId" in self.config: check_required_params( - self.config, - { + self.config, { "spClientId", "keyVaultName", "certName", "tenant", - }, + } ) def _init_sp(self) -> None: @@ -112,9 +104,7 @@ def _init_sp(self) -> None: cert_bytes = b64decode(secret.value) # Reauthenticate as the service principal. - self._cred = azure_id.CertificateCredential( - tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes - ) + self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes) def get_access_token(self) -> str: """ diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 3673baca76..9f2b504aff 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -29,9 +29,9 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): Helper methods to manage and deploy Azure resources via REST APIs. """ - _POLL_INTERVAL = 4 # seconds - _POLL_TIMEOUT = 300 # seconds - _REQUEST_TIMEOUT = 5 # seconds + _POLL_INTERVAL = 4 # seconds + _POLL_TIMEOUT = 300 # seconds + _REQUEST_TIMEOUT = 5 # seconds _REQUEST_TOTAL_RETRIES = 10 # Total number retries for each request _REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries})) @@ -39,21 +39,19 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): # https://docs.microsoft.com/en-us/rest/api/resources/deployments _URL_DEPLOY = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Resources" - + "/deployments/{deployment_name}" - + "?api-version=2022-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Resources" + + "/deployments/{deployment_name}" + + "?api-version=2022-05-01" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of an Azure Services proxy. @@ -71,44 +69,32 @@ def __init__( """ super().__init__(config, global_config, parent, methods) - check_required_params( - self.config, - [ - "subscription", - "resourceGroup", - ], - ) + check_required_params(self.config, [ + "subscription", + "resourceGroup", + ]) # These parameters can come from command line as strings, so conversion is needed. 
self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL)) self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT)) self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - self._total_retries = int( - self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES) - ) - self._backoff_factor = float( - self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR) - ) + self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)) + self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)) self._deploy_template = {} self._deploy_params = {} if self.config.get("deploymentTemplatePath") is not None: # TODO: Provide external schema validation? template = self.config_loader_service.load_config( - self.config["deploymentTemplatePath"], schema_type=None - ) + self.config['deploymentTemplatePath'], schema_type=None) assert template is not None and isinstance(template, dict) self._deploy_template = template # Allow for recursive variable expansion as we do with global params and const_args. - deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars( - extra_source_dict=global_config - ) + deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config) self._deploy_params = merge_parameters(dest=deploy_params, source=global_config) else: - _LOG.info( - "No deploymentTemplatePath provided. Deployment services will be unavailable." - ) + _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.") @property def deploy_params(self) -> dict: @@ -143,8 +129,7 @@ def _get_session(self, params: dict) -> requests.Session: session = requests.Session() session.mount( "https://", - HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)), - ) + HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor))) session.headers.update(self._get_headers()) return session @@ -152,9 +137,8 @@ def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance( - self._parent, SupportsAuth - ), "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ + "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() @staticmethod @@ -250,11 +234,9 @@ def _check_operation_status(self, params: dict) -> Tuple[Status, dict]: return (Status.FAILED, {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Response: %s\n%s", - response, - json.dumps(response.json(), indent=2) if response.content else "", - ) + _LOG.debug("Response: %s\n%s", response, + json.dumps(response.json(), indent=2) + if response.content else "") if response.status_code == 200: output = response.json() @@ -287,16 +269,12 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic Result is info on the operation runtime if SUCCEEDED, otherwise {}. 
""" params = self._set_default_params(params) - _LOG.info( - "Wait for %s to %s", - params.get("deploymentName"), - "provision" if is_setup else "deprovision", - ) + _LOG.info("Wait for %s to %s", params.get("deploymentName"), + "provision" if is_setup else "deprovision") return self._wait_while(self._check_deployment, Status.PENDING, params) - def _wait_while( - self, func: Callable[[dict], Tuple[Status, dict]], loop_status: Status, params: dict - ) -> Tuple[Status, dict]: + def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], + loop_status: Status, params: dict) -> Tuple[Status, dict]: """ Invoke `func` periodically while the status is equal to `loop_status`. Return TIMED_OUT when timing out. @@ -318,18 +296,12 @@ def _wait_while( """ params = self._set_default_params(params) config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"] - ) + dest=self.config.copy(), source=params, required_keys=["deploymentName"]) poll_period = params.get("pollInterval", self._poll_interval) - _LOG.debug( - "Wait for %s status %s :: poll %.2f timeout %d s", - config["deploymentName"], - loop_status, - poll_period, - self._poll_timeout, - ) + _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s", + config["deploymentName"], loop_status, poll_period, self._poll_timeout) ts_timeout = time.time() + self._poll_timeout poll_delay = poll_period @@ -353,9 +325,7 @@ def _wait_while( _LOG.warning("Request timed out: %s", params) return (Status.TIMED_OUT, {}) - def _check_deployment( - self, params: dict - ) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements + def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements """ Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. 
@@ -381,7 +351,7 @@ def _check_deployment( "subscription", "resourceGroup", "deploymentName", - ], + ] ) _LOG.info("Check deployment: %s", config["deploymentName"]) @@ -442,18 +412,13 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: if not self._deploy_template: raise ValueError(f"Missing deployment template: {self}") params = self._set_default_params(params) - config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"] - ) + config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"]) _LOG.info("Deploy: %s :: %s", config["deploymentName"], params) params = merge_parameters(dest=self._deploy_params.copy(), source=params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Deploy: %s merged params ::\n%s", - config["deploymentName"], - json.dumps(params, indent=2), - ) + _LOG.debug("Deploy: %s merged params ::\n%s", + config["deploymentName"], json.dumps(params, indent=2)) url = self._URL_DEPLOY.format( subscription=config["subscription"], @@ -466,26 +431,22 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: "mode": "Incremental", "template": self._deploy_template, "parameters": { - key: {"value": val} - for (key, val) in params.items() + key: {"value": val} for (key, val) in params.items() if key in self._deploy_template.get("parameters", {}) - }, + } } } if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2)) - response = requests.put( - url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout - ) + response = requests.put(url, json=json_req, + headers=self._get_headers(), timeout=self._request_timeout) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Response: %s\n%s", - response, - json.dumps(response.json(), indent=2) if response.content else "", - ) + _LOG.debug("Response: %s\n%s", response, + json.dumps(response.json(), indent=2) + if response.content else "") else: _LOG.info("Response: %s", response) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 717086b52e..6ccd4ba09d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -27,13 +27,11 @@ class AzureFileShareService(FileShareService): _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}" - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new file share Service for Azure environments with a given config. @@ -51,19 +49,16 @@ def __init__( New methods to register with the service. 
""" super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.upload, self.download]), + config, global_config, parent, + self.merge_methods(methods, [self.upload, self.download]) ) check_required_params( - self.config, - { + self.config, { "storageAccountName", "storageFileShareName", "storageAccountKey", - }, + } ) self._share_client = ShareClient.from_share_url( @@ -74,9 +69,7 @@ def __init__( credential=self.config["storageAccountKey"], ) - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: super().download(params, remote_path, local_path, recursive) dir_client = self._share_client.get_directory_client(remote_path) if dir_client.exists(): @@ -101,9 +94,7 @@ def download( # Translate into non-Azure exception: raise FileNotFoundError(f"Cannot download: {remote_path}") from ex - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: super().upload(params, local_path, remote_path, recursive) self._upload(local_path, remote_path, recursive, set()) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index 4ba8bd3903..d65ee02cfd 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -32,22 +32,20 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 _URL_DEPROVISION = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Network" - + "/virtualNetwork/{vnet_name}" - + "/delete" - + "?api-version=2023-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Network" + + "/virtualNetwork/{vnet_name}" + + "/delete" + + "?api-version=2023-05-01" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure Network services proxy. @@ -64,34 +62,25 @@ def __init__( New methods to register with the service. 
""" super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - # SupportsNetworkProvisioning - self.provision_network, - self.deprovision_network, - self.wait_network_deployment, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + # SupportsNetworkProvisioning + self.provision_network, + self.deprovision_network, + self.wait_network_deployment, + ]) ) if not self._deploy_template: - raise ValueError( - "AzureNetworkService requires a deployment template:\n" - + f"config={config}\nglobal_config={global_config}" - ) + raise ValueError("AzureNetworkService requires a deployment template:\n" + + f"config={config}\nglobal_config={global_config}") - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vnetName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vnetName']}-deployment" - _LOG.info( - "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] - ) + _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) return params def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: @@ -162,18 +151,15 @@ def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple "resourceGroup", "deploymentName", "vnetName", - ], + ] ) _LOG.info("Deprovision Network: %s", config["vnetName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) - (status, results) = self._azure_rest_api_post_helper( - config, - self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vnet_name=config["vnetName"], - ), - ) + (status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vnet_name=config["vnetName"], + )) if ignore_errors and status == Status.FAILED: _LOG.warning("Ignoring error: %s", results) status = Status.SUCCEEDED diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index e7f626f505..a92d279a6d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -32,22 +32,20 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig): # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations _URL_CONFIGURE = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/{provider}" - + "/{server_type}/{vm_name}" - + "/{update}" - + "?api-version={api_version}" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/{provider}" + + "/{server_type}/{vm_name}" + + "/{update}" + + "?api-version={api_version}" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, 
+ parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure services proxy. @@ -64,20 +62,18 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.configure, self.is_config_pending]), + config, global_config, parent, + self.merge_methods(methods, [ + self.configure, + self.is_config_pending + ]) ) - check_required_params( - self.config, - { - "subscription", - "resourceGroup", - "provider", - }, - ) + check_required_params(self.config, { + "subscription", + "resourceGroup", + "provider", + }) # Provide sane defaults for known DB providers. provider = self.config.get("provider") @@ -121,7 +117,8 @@ def __init__( # These parameters can come from command line as strings, so conversion is needed. self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: + def configure(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service. @@ -159,38 +156,33 @@ def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]: If "isConfigPendingReboot" is set to True, rebooting a VM is necessary. Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED} """ - config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters( + dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_get.format(vm_name=config["vmName"]) _LOG.debug("Request: GET %s", url) - response = requests.put(url, headers=self._get_headers(), timeout=self._request_timeout) + response = requests.put( + url, headers=self._get_headers(), timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) if response.status_code != 200: return (Status.FAILED, {}) # Currently, Azure Flex servers require a VM reboot. - return ( - Status.SUCCEEDED, - { - "isConfigPendingReboot": any( - {"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]] - for val in response.json()["value"] - ) - }, - ) + return (Status.SUCCEEDED, {"isConfigPendingReboot": any( + {'False': False, 'True': True}[val['properties']['isConfigPendingRestart']] + for val in response.json()['value'] + )}) def _get_headers(self) -> dict: """ Get the headers for the REST API calls. """ - assert self._parent is not None and isinstance( - self._parent, SupportsAuth - ), "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ + "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() - def _config_one( - self, config: Dict[str, Any], param_name: str, param_value: Any - ) -> Tuple[Status, dict]: + def _config_one(self, config: Dict[str, Any], + param_name: str, param_value: Any) -> Tuple[Status, dict]: """ Update a single parameter of the Azure DB service. @@ -209,15 +201,13 @@ def _config_one( A pair of Status and result. The result is always {}. 
Status is one of {PENDING, SUCCEEDED, FAILED} """ - config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters( + dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name) _LOG.debug("Request: PUT %s", url) - response = requests.put( - url, - headers=self._get_headers(), - json={"properties": {"value": str(param_value)}}, - timeout=self._request_timeout, - ) + response = requests.put(url, headers=self._get_headers(), + json={"properties": {"value": str(param_value)}}, + timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) @@ -225,7 +215,8 @@ def _config_one( return (Status.SUCCEEDED, {}) return (Status.FAILED, {}) - def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_many(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service one-by-one. (If batch API is not available for it). @@ -243,13 +234,14 @@ def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[ A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ - for param_name, param_value in params.items(): + for (param_name, param_value) in params.items(): (status, result) = self._config_one(config, param_name, param_value) if not status.is_succeeded(): return (status, result) return (Status.SUCCEEDED, {}) - def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_batch(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple[Status, dict]: """ Batch update the parameters of an Azure DB service. @@ -266,18 +258,19 @@ def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple A pair of Status and result. The result is always {}. 
Status is one of {PENDING, SUCCEEDED, FAILED} """ - config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters( + dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_set.format(vm_name=config["vmName"]) json_req = { "value": [ - {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items() + {"name": key, "properties": {"value": str(val)}} + for (key, val) in params.items() ], # "resetAllToDefault": "True" } _LOG.debug("Request: POST %s", url) - response = requests.post( - url, headers=self._get_headers(), json=json_req, timeout=self._request_timeout - ) + response = requests.post(url, headers=self._get_headers(), + json=json_req, timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index effb0f9499..ddce3cc935 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -26,13 +26,7 @@ _LOG = logging.getLogger(__name__) -class AzureVMService( - AzureDeploymentService, - SupportsHostProvisioning, - SupportsHostOps, - SupportsOSOps, - SupportsRemoteExec, -): +class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec): """ Helper methods to manage VMs on Azure. """ @@ -44,35 +38,35 @@ class AzureVMService( # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start _URL_START = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/start" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/start" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off _URL_STOP = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/powerOff" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/powerOff" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate _URL_DEALLOCATE = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/deallocate" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/deallocate" + + "?api-version=2022-03-01" ) # TODO: This is probably the more correct URL to use for the deprovision operation. 
@@ -94,33 +88,31 @@ class AzureVMService( # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart _URL_REBOOT = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/restart" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/restart" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command _URL_REXEC_RUN = ( - "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/runCommand" - + "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/runCommand" + + "?api-version=2022-03-01" ) - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of Azure VM services proxy. @@ -137,31 +129,26 @@ def __init__( New methods to register with the service. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - # SupportsHostProvisioning - self.provision_host, - self.deprovision_host, - self.deallocate_host, - self.wait_host_deployment, - # SupportsHostOps - self.start_host, - self.stop_host, - self.restart_host, - self.wait_host_operation, - # SupportsOSOps - self.shutdown, - self.reboot, - self.wait_os_operation, - # SupportsRemoteExec - self.remote_exec, - self.get_remote_exec_results, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + # SupportsHostProvisioning + self.provision_host, + self.deprovision_host, + self.deallocate_host, + self.wait_host_deployment, + # SupportsHostOps + self.start_host, + self.stop_host, + self.restart_host, + self.wait_host_operation, + # SupportsOSOps + self.shutdown, + self.reboot, + self.wait_os_operation, + # SupportsRemoteExec + self.remote_exec, + self.get_remote_exec_results, + ]) ) # As a convenience, allow reading customData out of a file, rather than @@ -170,23 +157,19 @@ def __init__( # can be done using the `base64()` string function inside the ARM template. 
self._custom_data_file = self.config.get("customDataFile", None) if self._custom_data_file: - if self._deploy_params.get("customData", None): + if self._deploy_params.get('customData', None): raise ValueError("Both customDataFile and customData are specified.") - self._custom_data_file = self.config_loader_service.resolve_path( - self._custom_data_file - ) - with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh: + self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file) + with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh: self._deploy_params["customData"] = custom_data_fh.read() - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vmName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vmName']}-deployment" - _LOG.info( - "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] - ) + _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) return params def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: @@ -281,19 +264,16 @@ def deprovision_host(self, params: dict) -> Tuple[Status, dict]: "resourceGroup", "deploymentName", "vmName", - ], + ] ) _LOG.info("Deprovision VM: %s", config["vmName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) # TODO: Properly deprovision *all* resources specified in the ARM template. - return self._azure_rest_api_post_helper( - config, - self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def deallocate_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -321,17 +301,14 @@ def deallocate_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Deallocate VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper( - config, - self._URL_DEALLOCATE.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def start_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -356,17 +333,14 @@ def start_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Start VM: %s :: %s", config["vmName"], params) - return self._azure_rest_api_post_helper( - config, - self._URL_START.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_START.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: """ @@ -393,17 
+367,14 @@ def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Stop VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper( - config, - self._URL_STOP.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_STOP.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.stop_host(params, force) @@ -433,24 +404,20 @@ def restart_host(self, params: dict, force: bool = False) -> Tuple[Status, dict] "subscription", "resourceGroup", "vmName", - ], + ] ) _LOG.info("Reboot VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper( - config, - self._URL_REBOOT.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - ), - ) + return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + )) def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.restart_host(params, force) - def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict - ) -> Tuple[Status, dict]: + def remote_exec(self, script: Iterable[str], config: dict, + env_params: dict) -> Tuple[Status, dict]: """ Run a command on Azure VM. @@ -480,7 +447,7 @@ def remote_exec( "subscription", "resourceGroup", "vmName", - ], + ] ) if _LOG.isEnabledFor(logging.INFO): @@ -489,7 +456,7 @@ def remote_exec( json_req = { "commandId": "RunShellScript", "script": list(script), - "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()], + "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()] } url = self._URL_REXEC_RUN.format( @@ -502,15 +469,12 @@ def remote_exec( _LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2)) response = requests.post( - url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout - ) + url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Response: %s\n%s", - response, - json.dumps(response.json(), indent=2) if response.content else "", - ) + _LOG.debug("Response: %s\n%s", response, + json.dumps(response.json(), indent=2) + if response.content else "") else: _LOG.info("Response: %s", response) @@ -518,10 +482,10 @@ def remote_exec( # TODO: extract the results from JSON response return (Status.SUCCEEDED, config) elif response.status_code == 202: - return ( - Status.PENDING, - {**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")}, - ) + return (Status.PENDING, { + **config, + "asyncResultsUrl": response.headers.get("Azure-AsyncOperation") + }) else: _LOG.error("Response: %s :: %s", response, response.text) # _LOG.error("Bad Request:\n%s", response.request.body) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index f136747f7f..f623cdfcc8 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -31,14 +31,9 @@ class CopyMode(Enum): class 
SshFileShareService(FileShareService, SshService): """A collection of functions for interacting with SSH servers as file shares.""" - async def _start_file_copy( - self, - params: dict, - mode: CopyMode, - local_path: str, - remote_path: str, - recursive: bool = True, - ) -> None: + async def _start_file_copy(self, params: dict, mode: CopyMode, + local_path: str, remote_path: str, + recursive: bool = True) -> None: # pylint: disable=too-many-arguments """ Starts a file copy operation @@ -78,52 +73,40 @@ async def _start_file_copy( raise ValueError(f"Unknown copy mode: {mode}") return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True) - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ], + ] ) super().download(params, remote_path, local_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive) - ) + self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive)) try: file_copy_future.result() except (OSError, SFTPError) as ex: - _LOG.error( - "Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex - ) + _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex) if isinstance(ex, SFTPNoSuchFile) or ( - isinstance(ex, SFTPFailure) - and ex.code == 4 - and any( - msg.lower() in ex.reason.lower() - for msg in ("File not found", "No such file or directory") - ) + isinstance(ex, SFTPFailure) and ex.code == 4 + and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory")) ): _LOG.warning("File %s does not exist on %s", remote_path, params) raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex raise ex - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ], + ] ) super().upload(params, local_path, remote_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive) - ) + self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)) try: file_copy_future.result() except (OSError, SFTPError) as ex: diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index f04544eb05..a650ff0707 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -29,13 +29,11 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, 
Callable], List[Callable], None] = None): """ Create a new instance of an SSH Service. @@ -54,25 +52,17 @@ def __init__( # Same methods are also provided by the AzureVMService class # pylint: disable=duplicate-code super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.shutdown, - self.reboot, - self.wait_os_operation, - self.remote_exec, - self.get_remote_exec_results, - ], - ), - ) + config, global_config, parent, + self.merge_methods(methods, [ + self.shutdown, + self.reboot, + self.wait_os_operation, + self.remote_exec, + self.get_remote_exec_results, + ])) self._shell = self.config.get("ssh_shell", "/bin/bash") - async def _run_cmd( - self, params: dict, script: Iterable[str], env_params: dict - ) -> SSHCompletedProcess: + async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess: """ Runs a command asynchronously on a host via SSH. @@ -95,19 +85,16 @@ async def _run_cmd( # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values. # Handle transferring environment variables by making a script to set them. env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()] - script_lines = env_script_lines + [ - line_split for line in script for line_split in line.splitlines() - ] + script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()] # Note: connection.run() uses "exec" with a shell by default. - script_str = "\n".join(script_lines) + script_str = '\n'.join(script_lines) _LOG.debug("Running script on %s:\n%s", connection, script_str) - return await connection.run( - script_str, check=False, timeout=self._request_timeout, env=env_params - ) + return await connection.run(script_str, + check=False, + timeout=self._request_timeout, + env=env_params) - def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict - ) -> Tuple["Status", dict]: + def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]: """ Start running a command on remote host OS. 
@@ -134,11 +121,9 @@ def remote_exec( source=config, required_keys=[ "ssh_hostname", - ], - ) - config["asyncRemoteExecResultsFuture"] = self._run_coroutine( - self._run_cmd(config, script, env_params) + ] ) + config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params)) return (Status.PENDING, config) def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: @@ -169,11 +154,7 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr return ( - ( - Status.SUCCEEDED - if result.exit_status == 0 and result.returncode == 0 - else Status.FAILED - ), + Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED, { "stdout": stdout, "stderr": stderr, @@ -205,9 +186,9 @@ def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, d source=params, required_keys=[ "ssh_hostname", - ], + ] ) - cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list]) + cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list]) script = rf""" if [[ $EUID -ne 0 ]]; then sudo=$(command -v sudo) @@ -242,10 +223,10 @@ def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - "shutdown -h now", - "poweroff", - "halt -p", - "systemctl poweroff", + 'shutdown -h now', + 'poweroff', + 'halt -p', + 'systemctl poweroff', ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) @@ -267,11 +248,11 @@ def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - "shutdown -r now", - "reboot", - "halt --reboot", - "systemctl reboot", - "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1", + 'shutdown -r now', + 'reboot', + 'halt --reboot', + 'systemctl reboot', + 'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1', ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 64bb7d9788..8bc90eb3da 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -50,8 +50,8 @@ class SshClient(asyncssh.SSHClient): reconnect for each command. """ - _CONNECTION_PENDING = "INIT" - _CONNECTION_LOST = "LOST" + _CONNECTION_PENDING = 'INIT' + _CONNECTION_LOST = 'LOST' def __init__(self, *args: tuple, **kwargs: dict): self._connection_id: str = SshClient._CONNECTION_PENDING @@ -65,7 +65,7 @@ def __repr__(self) -> str: @staticmethod def id_from_connection(connection: SSHClientConnection) -> str: """Gets a unique id repr for the connection.""" - return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access + return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access @staticmethod def id_from_params(connect_params: dict) -> str: @@ -79,9 +79,8 @@ def connection_made(self, conn: SSHClientConnection) -> None: Changes the connection_id from _CONNECTION_PENDING to a unique id repr. 
""" self._conn_event.clear() - _LOG.debug( - "%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn - ) # pylint: disable=protected-access + _LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \ + # pylint: disable=protected-access self._connection_id = SshClient.id_from_connection(conn) self._connection = conn self._conn_event.set() @@ -91,19 +90,9 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self._conn_event.clear() _LOG.debug("%s: %s", current_thread().name, "connection_lost") if exc is None: - _LOG.debug( - "%s: gracefully disconnected ssh from %s: %s", - current_thread().name, - self._connection_id, - exc, - ) + _LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc) else: - _LOG.debug( - "%s: ssh connection lost on %s: %s", - current_thread().name, - self._connection_id, - exc, - ) + _LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc) self._connection_id = SshClient._CONNECTION_LOST self._connection = None self._conn_event.set() @@ -156,9 +145,7 @@ def exit(self) -> None: warn(RuntimeWarning("SshClientCache lock was still held on exit.")) self._cache_lock.release() - async def get_client_connection( - self, connect_params: dict - ) -> Tuple[SSHClientConnection, SshClient]: + async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) client connection. @@ -181,21 +168,13 @@ async def get_client_connection( _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id) connection = await client.connection() if not connection: - _LOG.debug( - "%s: Removing stale client connection %s from cache.", - current_thread().name, - connection_id, - ) + _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id) self._cache.pop(connection_id) # Try to reconnect next. else: _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id) if connection_id not in self._cache: - _LOG.debug( - "%s: Establishing client connection to %s", - current_thread().name, - connection_id, - ) + _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id) connection, client = await asyncssh.create_connection(SshClient, **connect_params) assert isinstance(client, SshClient) self._cache[connection_id] = (connection, client) @@ -206,7 +185,7 @@ def cleanup(self) -> None: """ Closes all cached connections. """ - for connection, _ in self._cache.values(): + for (connection, _) in self._cache.values(): connection.close() self._cache = {} @@ -246,23 +225,21 @@ class SshService(Service, metaclass=ABCMeta): _REQUEST_TIMEOUT: Optional[float] = None # seconds - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__(config, global_config, parent, methods) # Make sure that the value we allow overriding on a per-connection # basis are present in the config so merge_parameters can do its thing. 
- self.config.setdefault("ssh_port", None) - assert isinstance(self.config["ssh_port"], (int, type(None))) - self.config.setdefault("ssh_username", None) - assert isinstance(self.config["ssh_username"], (str, type(None))) - self.config.setdefault("ssh_priv_key_path", None) - assert isinstance(self.config["ssh_priv_key_path"], (str, type(None))) + self.config.setdefault('ssh_port', None) + assert isinstance(self.config['ssh_port'], (int, type(None))) + self.config.setdefault('ssh_username', None) + assert isinstance(self.config['ssh_username'], (str, type(None))) + self.config.setdefault('ssh_priv_key_path', None) + assert isinstance(self.config['ssh_priv_key_path'], (str, type(None))) # None can be used to disable the request timeout. self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT) @@ -273,24 +250,24 @@ def __init__( # In general scripted commands shouldn't need a pty and having one # available can confuse some commands, though we may need to make # this configurable in the future. - "request_pty": False, + 'request_pty': False, # By default disable known_hosts checking (since most VMs expected to be dynamically created). - "known_hosts": None, + 'known_hosts': None, } - if "ssh_known_hosts_file" in self.config: - self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None) - if isinstance(self._connect_params["known_hosts"], str): - known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"]) + if 'ssh_known_hosts_file' in self.config: + self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None) + if isinstance(self._connect_params['known_hosts'], str): + known_hosts_file = os.path.expanduser(self._connect_params['known_hosts']) if not os.path.exists(known_hosts_file): raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist") - self._connect_params["known_hosts"] = known_hosts_file - if self._connect_params["known_hosts"] is None: + self._connect_params['known_hosts'] = known_hosts_file + if self._connect_params['known_hosts'] is None: _LOG.info("%s known_hosts checking is disabled per config.", self) - if "ssh_keepalive_interval" in self.config: - keepalive_internal = self.config.get("ssh_keepalive_interval") - self._connect_params["keepalive_interval"] = nullable(int, keepalive_internal) + if 'ssh_keepalive_interval' in self.config: + keepalive_internal = self.config.get('ssh_keepalive_interval') + self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal) def _enter_context(self) -> "SshService": # Start the background thread if it's not already running. @@ -300,12 +277,9 @@ def _enter_context(self) -> "SshService": super()._enter_context() return self - def _exit_context( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def _exit_context(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: # Stop the background thread if it's not needed anymore and potentially # cleanup the cache as well. assert self._in_context @@ -360,26 +334,24 @@ def _get_connect_params(self, params: dict) -> dict: # Start with the base config params. 
connect_params = self._connect_params.copy() - connect_params["host"] = params["ssh_hostname"] # required + connect_params['host'] = params['ssh_hostname'] # required - if params.get("ssh_port"): - connect_params["port"] = int(params.pop("ssh_port")) - elif self.config["ssh_port"]: - connect_params["port"] = int(self.config["ssh_port"]) + if params.get('ssh_port'): + connect_params['port'] = int(params.pop('ssh_port')) + elif self.config['ssh_port']: + connect_params['port'] = int(self.config['ssh_port']) - if "ssh_username" in params: - connect_params["username"] = str(params.pop("ssh_username")) - elif self.config["ssh_username"]: - connect_params["username"] = str(self.config["ssh_username"]) + if 'ssh_username' in params: + connect_params['username'] = str(params.pop('ssh_username')) + elif self.config['ssh_username']: + connect_params['username'] = str(self.config['ssh_username']) - priv_key_file: Optional[str] = params.get( - "ssh_priv_key_path", self.config["ssh_priv_key_path"] - ) + priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path']) if priv_key_file: priv_key_file = os.path.expanduser(priv_key_file) if not os.path.exists(priv_key_file): raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist") - connect_params["client_keys"] = [priv_key_file] + connect_params['client_keys'] = [priv_key_file] return connect_params @@ -398,6 +370,4 @@ async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnectio The connection and client objects. """ assert self._in_context - return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection( - self._get_connect_params(params) - ) + return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params)) diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index 02bb06e755..725d0c3306 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -18,12 +18,12 @@ from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec __all__ = [ - "SupportsAuth", - "SupportsConfigLoading", - "SupportsFileShareOps", - "SupportsHostProvisioning", - "SupportsLocalExec", - "SupportsNetworkProvisioning", - "SupportsRemoteConfig", - "SupportsRemoteExec", + 'SupportsAuth', + 'SupportsConfigLoading', + 'SupportsFileShareOps', + 'SupportsHostProvisioning', + 'SupportsLocalExec', + 'SupportsNetworkProvisioning', + 'SupportsRemoteConfig', + 'SupportsRemoteExec', ] diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index b09788476f..05853da0a9 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -34,7 +34,8 @@ class SupportsConfigLoading(Protocol): Protocol interface for helper functions to lookup and load configs. """ - def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path(self, file_path: str, + extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -52,9 +53,7 @@ def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = No An actual path to the config or script. 
""" - def load_config( - self, json_file_name: str, schema_type: Optional[ConfigSchema] - ) -> Union[dict, List[dict]]: + def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. @@ -73,14 +72,12 @@ def load_config( Free-format dictionary that contains the configuration. """ - def build_environment( - self, # pylint: disable=too-many-arguments - config: dict, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None, - ) -> "Environment": + def build_environment(self, # pylint: disable=too-many-arguments + config: dict, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None) -> "Environment": """ Factory method for a new environment with a given config. @@ -110,13 +107,12 @@ def build_environment( """ def load_environment_list( # pylint: disable=too-many-arguments - self, - json_file_name: str, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None, - ) -> List["Environment"]: + self, + json_file_name: str, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None) -> List["Environment"]: """ Load and build a list of environments from the config file. @@ -141,12 +137,9 @@ def load_environment_list( # pylint: disable=too-many-arguments A list of new benchmarking environments. """ - def load_services( - self, - json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - ) -> "Service": + def load_services(self, json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None) -> "Service": """ Read the configuration files and bundle all service methods from those configs into a single Service object. diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py index 8252dc17ed..87ec9e49da 100644 --- a/mlos_bench/mlos_bench/services/types/fileshare_type.py +++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py @@ -15,9 +15,7 @@ class SupportsFileShareOps(Protocol): Protocol interface for file share operations. """ - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: """ Downloads contents from a remote share path to a local path. @@ -35,9 +33,7 @@ def download( if True (the default), download the entire directory tree. """ - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: """ Uploads contents from a local path to remote share path. 
diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index 126966c713..c4c5f01ddc 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -32,12 +32,9 @@ class SupportsLocalExec(Protocol): Used in LocalEnv and provided by LocalExecService. """ - def local_exec( - self, - script_lines: Iterable[str], - env: Optional[Mapping[str, TunableValue]] = None, - cwd: Optional[str] = None, - ) -> Tuple[int, str, str]: + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, TunableValue]] = None, + cwd: Optional[str] = None) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. @@ -58,9 +55,7 @@ def local_exec( A 3-tuple of return code, stdout, and stderr of the script process. """ - def temp_dir_context( - self, path: Optional[str] = None - ) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: + def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 50b24cc4b8..fb753aa21c 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -56,9 +56,7 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ - def deprovision_network( - self, params: dict, ignore_errors: bool = True - ) -> Tuple["Status", dict]: + def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]: """ Deprovisions the Network by deleting it. diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index f93de1eab1..c653e10c2b 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -18,7 +18,8 @@ class SupportsRemoteConfig(Protocol): Protocol interface for configuring cloud services. """ - def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple["Status", dict]: + def configure(self, config: Dict[str, Any], + params: Dict[str, Any]) -> Tuple["Status", dict]: """ Update the parameters of a SaaS service in the cloud. diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index f6ca57912a..096cb3c675 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -20,9 +20,8 @@ class SupportsRemoteExec(Protocol): scripts on a remote host OS. """ - def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict - ) -> Tuple["Status", dict]: + def remote_exec(self, script: Iterable[str], config: dict, + env_params: dict) -> Tuple["Status", dict]: """ Run a command on remote host OS. 
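Aside: the classes touched in this services/types group (SupportsFileShareOps, SupportsLocalExec, SupportsRemoteExec, and the rest) are typing.Protocol interfaces, so a concrete service satisfies them structurally, by providing methods with matching signatures, rather than by inheriting from them. A minimal sketch of that pattern; SupportsPing, FakePingService, and check_host are hypothetical names, not part of the repo:

from typing import Protocol, Tuple


class SupportsPing(Protocol):
    """Hypothetical protocol, analogous to the Supports* interfaces above."""

    def ping(self, host: str) -> Tuple[int, str]:
        ...


class FakePingService:
    """Satisfies SupportsPing structurally; note the lack of inheritance."""

    def ping(self, host: str) -> Tuple[int, str]:
        return (0, f"{host} is reachable")


def check_host(service: SupportsPing, host: str) -> None:
    # mypy accepts FakePingService here because its method signatures
    # match the protocol; no registration or subclassing is required.
    return_code, message = service.ping(host)
    print(return_code, message)


check_host(FakePingService(), "localhost")
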
diff --git a/mlos_bench/mlos_bench/storage/__init__.py b/mlos_bench/mlos_bench/storage/__init__.py index 0812270747..9ae5c80f36 100644 --- a/mlos_bench/mlos_bench/storage/__init__.py +++ b/mlos_bench/mlos_bench/storage/__init__.py @@ -10,6 +10,6 @@ from mlos_bench.storage.storage_factory import from_config __all__ = [ - "Storage", - "from_config", + 'Storage', + 'from_config', ] diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index 47581f0725..ce07e44e2b 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -32,15 +32,12 @@ class ExperimentData(metaclass=ABCMeta): RESULT_COLUMN_PREFIX = "result." CONFIG_COLUMN_PREFIX = "config." - def __init__( - self, - *, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str, - ): + def __init__(self, *, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str): self._experiment_id = experiment_id self._description = description self._root_env_config = root_env_config @@ -145,9 +142,9 @@ def default_tunable_config_id(self) -> Optional[int]: trials_items = sorted(self.trials.items()) if not trials_items: return None - for _trial_id, trial in trials_items: + for (_trial_id, trial) in trials_items: # Take the first config id marked as "defaults" when it was instantiated. - if strtobool(str(trial.metadata_dict.get("is_defaults", False))): + if strtobool(str(trial.metadata_dict.get('is_defaults', False))): return trial.tunable_config_id # Fallback (min trial_id) return trials_items[0][1].tunable_config_id diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index b7df86a4b7..2165fa706f 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -30,12 +30,10 @@ class Storage(metaclass=ABCMeta): and storage systems (e.g., SQLite or MLFLow). """ - def __init__( - self, - config: Dict[str, Any], - global_config: Optional[dict] = None, - service: Optional[Service] = None, - ): + def __init__(self, + config: Dict[str, Any], + global_config: Optional[dict] = None, + service: Optional[Service] = None): """ Create a new storage object. @@ -76,16 +74,13 @@ def experiments(self) -> Dict[str, ExperimentData]: """ @abstractmethod - def experiment( - self, - *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal["min", "max"]], - ) -> "Storage.Experiment": + def experiment(self, *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment': """ Create a new experiment in the storage. @@ -121,27 +116,23 @@ class Experiment(metaclass=ABCMeta): This class is instantiated in the `Storage.experiment()` method. 
""" - def __init__( - self, - *, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal["min", "max"]], - ): + def __init__(self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal['min', 'max']]): self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( - root_env_config - ) + (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config) self._description = description self._opt_targets = opt_targets self._in_context = False - def __enter__(self) -> "Storage.Experiment": + def __enter__(self) -> 'Storage.Experiment': """ Enter the context of the experiment. @@ -153,12 +144,9 @@ def __enter__(self) -> "Storage.Experiment": self._in_context = True return self - def __exit__( - self, - exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Literal[False]: """ End the context of the experiment. @@ -169,9 +157,8 @@ def __exit__( _LOG.debug("Finishing experiment: %s", self) else: assert exc_type and exc_val - _LOG.warning( - "Finishing experiment: %s", self, exc_info=(exc_type, exc_val, exc_tb) - ) + _LOG.warning("Finishing experiment: %s", self, + exc_info=(exc_type, exc_val, exc_tb)) assert self._in_context self._teardown(is_ok) self._in_context = False @@ -261,10 +248,8 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: """ @abstractmethod - def load( - self, - last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load(self, last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: """ Load (tunable values, benchmark scores, status) to warm-up the optimizer. @@ -284,9 +269,7 @@ def load( """ @abstractmethod - def pending_trials( - self, timestamp: datetime, *, running: bool - ) -> Iterator["Storage.Trial"]: + def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']: """ Return an iterator over the pending trials that are scheduled to run on or before the specified timestamp. @@ -306,12 +289,8 @@ def pending_trials( """ @abstractmethod - def new_trial( - self, - tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None, - ) -> "Storage.Trial": + def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial': """ Create a new experiment run in the storage. @@ -338,16 +317,10 @@ class Trial(metaclass=ABCMeta): This class is instantiated in the `Storage.Experiment.trial()` method. 
""" - def __init__( - self, - *, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - tunable_config_id: int, - opt_targets: Dict[str, Literal["min", "max"]], - config: Optional[Dict[str, Any]] = None, - ): + def __init__(self, *, + tunables: TunableGroups, experiment_id: str, trial_id: int, + tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']], + config: Optional[Dict[str, Any]] = None): self._tunables = tunables self._experiment_id = experiment_id self._trial_id = trial_id @@ -405,9 +378,9 @@ def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, An return config @abstractmethod - def update( - self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update(self, status: Status, timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: """ Update the storage with the results of the experiment. @@ -431,18 +404,14 @@ def update( assert metrics is not None opt_targets = set(self._opt_targets.keys()) if not opt_targets.issubset(metrics.keys()): - _LOG.warning( - "Trial %s :: opt.targets missing: %s", - self, - opt_targets.difference(metrics.keys()), - ) + _LOG.warning("Trial %s :: opt.targets missing: %s", + self, opt_targets.difference(metrics.keys())) # raise ValueError() return metrics @abstractmethod - def update_telemetry( - self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] - ) -> None: + def update_telemetry(self, status: Status, timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]]) -> None: """ Save the experiment's telemetry data and intermediate status. diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index 6ad397d753..b3b2bed86a 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -31,16 +31,13 @@ class TrialData(metaclass=ABCMeta): of tunable parameters). """ - def __init__( - self, - *, - experiment_id: str, - trial_id: int, - tunable_config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status, - ): + def __init__(self, *, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status): self._experiment_id = experiment_id self._trial_id = trial_id self._tunable_config_id = tunable_config_id diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py index 0c9adce22d..0dce110b1b 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py @@ -21,7 +21,8 @@ class TunableConfigData(metaclass=ABCMeta): A configuration in this context is the set of tunable parameter values. 
""" - def __init__(self, *, tunable_config_id: int): + def __init__(self, *, + tunable_config_id: int): self._tunable_config_id = tunable_config_id def __repr__(self) -> str: diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py index 6ad0fe185a..18c50035a9 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py @@ -27,13 +27,10 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta): (e.g., for repeats), which we call a (tunable) config trial group. """ - def __init__( - self, - *, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None, - ): + def __init__(self, *, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None): self._experiment_id = experiment_id self._tunable_config_id = tunable_config_id # can be lazily initialized as necessary: @@ -80,10 +77,7 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return ( - self._tunable_config_id == other._tunable_config_id - and self._experiment_id == other._experiment_id - ) + return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id @property @abstractmethod diff --git a/mlos_bench/mlos_bench/storage/sql/__init__.py b/mlos_bench/mlos_bench/storage/sql/__init__.py index cf09b9aa5a..735e21bcaf 100644 --- a/mlos_bench/mlos_bench/storage/sql/__init__.py +++ b/mlos_bench/mlos_bench/storage/sql/__init__.py @@ -8,5 +8,5 @@ from mlos_bench.storage.sql.storage import SqlStorage __all__ = [ - "SqlStorage", + 'SqlStorage', ] diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index bdeb6d8bf3..c7ee73a3bc 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -18,8 +18,10 @@ def get_trials( - engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None -) -> Dict[int, TrialData]: + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. @@ -28,18 +30,13 @@ def get_trials( from mlos_bench.storage.sql.trial_data import ( TrialSqlData, # pylint: disable=import-outside-toplevel,cyclic-import ) - with engine.connect() as conn: # Build up sql a statement for fetching trials. - stmt = ( - schema.trial.select() - .where( - schema.trial.c.exp_id == experiment_id, - ) - .order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), - ) + stmt = schema.trial.select().where( + schema.trial.c.exp_id == experiment_id, + ).order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -63,8 +60,10 @@ def get_trials( def get_results_df( - engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None -) -> pandas.DataFrame: + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: Optional[int] = None) -> pandas.DataFrame: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. 
@@ -73,22 +72,15 @@ def get_results_df( # pylint: disable=too-many-locals with engine.connect() as conn: # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config. - tunable_config_group_id_stmt = ( - schema.trial.select() - .with_only_columns( - schema.trial.c.exp_id, - schema.trial.c.config_id, - func.min(schema.trial.c.trial_id) - .cast(Integer) - .label("tunable_config_trial_group_id"), - ) - .where( - schema.trial.c.exp_id == experiment_id, - ) - .group_by( - schema.trial.c.exp_id, - schema.trial.c.config_id, - ) + tunable_config_group_id_stmt = schema.trial.select().with_only_columns( + schema.trial.c.exp_id, + schema.trial.c.config_id, + func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'), + ).where( + schema.trial.c.exp_id == experiment_id, + ).group_by( + schema.trial.c.exp_id, + schema.trial.c.config_id, ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -98,22 +90,18 @@ def get_results_df( tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery() # Get each trial's metadata. - cur_trials_stmt = ( - select( - schema.trial, - tunable_config_trial_group_id_subquery, - ) - .where( - schema.trial.c.exp_id == experiment_id, - and_( - tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, - tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, - ), - ) - .order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), - ) + cur_trials_stmt = select( + schema.trial, + tunable_config_trial_group_id_subquery, + ).where( + schema.trial.c.exp_id == experiment_id, + and_( + tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, + tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, + ), + ).order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -122,48 +110,39 @@ def get_results_df( ) cur_trials = conn.execute(cur_trials_stmt) trials_df = pandas.DataFrame( - [ - ( - row.trial_id, - utcify_timestamp(row.ts_start, origin="utc"), - utcify_nullable_timestamp(row.ts_end, origin="utc"), - row.config_id, - row.tunable_config_trial_group_id, - row.status, - ) - for row in cur_trials.fetchall() - ], + [( + row.trial_id, + utcify_timestamp(row.ts_start, origin="utc"), + utcify_nullable_timestamp(row.ts_end, origin="utc"), + row.config_id, + row.tunable_config_trial_group_id, + row.status, + ) for row in cur_trials.fetchall()], columns=[ - "trial_id", - "ts_start", - "ts_end", - "tunable_config_id", - "tunable_config_trial_group_id", - "status", - ], + 'trial_id', + 'ts_start', + 'ts_end', + 'tunable_config_id', + 'tunable_config_trial_group_id', + 'status', + ] ) # Get each trial's config in wide format. 
- configs_stmt = ( - schema.trial.select() - .with_only_columns( - schema.trial.c.trial_id, - schema.trial.c.config_id, - schema.config_param.c.param_id, - schema.config_param.c.param_value, - ) - .where( - schema.trial.c.exp_id == experiment_id, - ) - .join( - schema.config_param, - schema.config_param.c.config_id == schema.trial.c.config_id, - isouter=True, - ) - .order_by( - schema.trial.c.trial_id, - schema.config_param.c.param_id, - ) + configs_stmt = schema.trial.select().with_only_columns( + schema.trial.c.trial_id, + schema.trial.c.config_id, + schema.config_param.c.param_id, + schema.config_param.c.param_value, + ).where( + schema.trial.c.exp_id == experiment_id, + ).join( + schema.config_param, + schema.config_param.c.config_id == schema.trial.c.config_id, + isouter=True + ).order_by( + schema.trial.c.trial_id, + schema.config_param.c.param_id, ) if tunable_config_id is not None: configs_stmt = configs_stmt.where( @@ -171,67 +150,41 @@ def get_results_df( ) configs = conn.execute(configs_stmt) configs_df = pandas.DataFrame( - [ - ( - row.trial_id, - row.config_id, - ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, - row.param_value, - ) - for row in configs.fetchall() - ], - columns=["trial_id", "tunable_config_id", "param", "value"], + [(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value) + for row in configs.fetchall()], + columns=['trial_id', 'tunable_config_id', 'param', 'value'] ).pivot( - index=["trial_id", "tunable_config_id"], - columns="param", - values="value", + index=["trial_id", "tunable_config_id"], columns="param", values="value", ) - configs_df = configs_df.apply(pandas.to_numeric, errors="coerce").fillna(configs_df) # type: ignore[assignment] # (fp) + configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp) # Get each trial's results in wide format. 
- results_stmt = ( - schema.trial_result.select() - .with_only_columns( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, - schema.trial_result.c.metric_value, - ) - .where( - schema.trial_result.c.exp_id == experiment_id, - ) - .order_by( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, - ) + results_stmt = schema.trial_result.select().with_only_columns( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + schema.trial_result.c.metric_value, + ).where( + schema.trial_result.c.exp_id == experiment_id, + ).order_by( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, ) if tunable_config_id is not None: - results_stmt = results_stmt.join( - schema.trial, - and_( - schema.trial.c.exp_id == schema.trial_result.c.exp_id, - schema.trial.c.trial_id == schema.trial_result.c.trial_id, - schema.trial.c.config_id == tunable_config_id, - ), - ) + results_stmt = results_stmt.join(schema.trial, and_( + schema.trial.c.exp_id == schema.trial_result.c.exp_id, + schema.trial.c.trial_id == schema.trial_result.c.trial_id, + schema.trial.c.config_id == tunable_config_id, + )) results = conn.execute(results_stmt) results_df = pandas.DataFrame( - [ - ( - row.trial_id, - ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, - row.metric_value, - ) - for row in results.fetchall() - ], - columns=["trial_id", "metric", "value"], + [(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value) + for row in results.fetchall()], + columns=['trial_id', 'metric', 'value'] ).pivot( - index="trial_id", - columns="metric", - values="value", + index="trial_id", columns="metric", values="value", ) - results_df = results_df.apply(pandas.to_numeric, errors="coerce").fillna(results_df) # type: ignore[assignment] # (fp) + results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp) # Concat the trials, configs, and results. - return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge( - results_df, on="trial_id", how="left" - ) + return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \ + .merge(results_df, on="trial_id", how="left") diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index e6322c7ade..58ee3dddb5 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -29,18 +29,15 @@ class Experiment(Storage.Experiment): Logic for retrieving and storing the results of a single experiment. """ - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal["min", "max"]], - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal['min', 'max']]): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -58,22 +55,18 @@ def _setup(self) -> None: # Get git info and the last trial ID for the experiment. 
# pylint: disable=not-callable exp_info = conn.execute( - self._schema.experiment.select() - .with_only_columns( + self._schema.experiment.select().with_only_columns( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, func.max(self._schema.trial.c.trial_id).label("trial_id"), - ) - .join( + ).join( self._schema.trial, self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id, - isouter=True, - ) - .where( + isouter=True + ).where( self._schema.experiment.c.exp_id == self._experiment_id, - ) - .group_by( + ).group_by( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, @@ -82,47 +75,33 @@ def _setup(self) -> None: if exp_info is None: _LOG.info("Start new experiment: %s", self._experiment_id) # It's a new experiment: create a record for it in the database. - conn.execute( - self._schema.experiment.insert().values( - exp_id=self._experiment_id, - description=self._description, - git_repo=self._git_repo, - git_commit=self._git_commit, - root_env_config=self._root_env_config, - ) - ) - conn.execute( - self._schema.objectives.insert().values( - [ - { - "exp_id": self._experiment_id, - "optimization_target": opt_target, - "optimization_direction": opt_dir, - } - for (opt_target, opt_dir) in self.opt_targets.items() - ] - ) - ) + conn.execute(self._schema.experiment.insert().values( + exp_id=self._experiment_id, + description=self._description, + git_repo=self._git_repo, + git_commit=self._git_commit, + root_env_config=self._root_env_config, + )) + conn.execute(self._schema.objectives.insert().values([ + { + "exp_id": self._experiment_id, + "optimization_target": opt_target, + "optimization_direction": opt_dir, + } + for (opt_target, opt_dir) in self.opt_targets.items() + ])) else: if exp_info.trial_id is not None: self._trial_id = exp_info.trial_id + 1 - _LOG.info( - "Continue experiment: %s last trial: %s resume from: %d", - self._experiment_id, - exp_info.trial_id, - self._trial_id, - ) + _LOG.info("Continue experiment: %s last trial: %s resume from: %d", + self._experiment_id, exp_info.trial_id, self._trial_id) # TODO: Sanity check that certain critical configs (e.g., # objectives) haven't changed to be incompatible such that a new # experiment should be started (possibly by prewarming with the # previous one). if exp_info.git_commit != self._git_commit: - _LOG.warning( - "Experiment %s git expected: %s %s", - self, - exp_info.git_repo, - exp_info.git_commit, - ) + _LOG.warning("Experiment %s git expected: %s %s", + self, exp_info.git_repo, exp_info.git_commit) def merge(self, experiment_ids: List[str]) -> None: _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids) @@ -135,42 +114,33 @@ def load_tunable_config(self, config_id: int) -> Dict[str, Any]: def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select() - .where( + self._schema.trial_telemetry.select().where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == trial_id, - ) - .order_by( + self._schema.trial_telemetry.c.trial_id == trial_id + ).order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) ) # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. 
- return [ - (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) - for row in cur_telemetry.fetchall() - ] + return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) + for row in cur_telemetry.fetchall()] - def load( - self, - last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load(self, last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: with self._engine.connect() as conn: cur_trials = conn.execute( - self._schema.trial.select() - .with_only_columns( + self._schema.trial.select().with_only_columns( self._schema.trial.c.trial_id, self._schema.trial.c.config_id, self._schema.trial.c.status, - ) - .where( + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id > last_trial_id, - self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]), - ) - .order_by( + self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']), + ).order_by( self._schema.trial.c.trial_id.asc(), ) ) @@ -184,21 +154,12 @@ def load( stat = Status[trial.status] status.append(stat) trial_ids.append(trial.trial_id) - configs.append( - self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id - ) - ) + configs.append(self._get_key_val( + conn, self._schema.config_param, "param", config_id=trial.config_id)) if stat.is_succeeded(): - scores.append( - self._get_key_val( - conn, - self._schema.trial_result, - "metric", - exp_id=self._experiment_id, - trial_id=trial.trial_id, - ) - ) + scores.append(self._get_key_val( + conn, self._schema.trial_result, "metric", + exp_id=self._experiment_id, trial_id=trial.trial_id)) else: scores.append(None) @@ -214,59 +175,49 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D select( column(f"{field}_id"), column(f"{field}_value"), + ).select_from(table).where( + *[column(key) == val for (key, val) in kwargs.items()] ) - .select_from(table) - .where(*[column(key) == val for (key, val) in kwargs.items()]) ) # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts. 
- return dict( - row._tuple() for row in cur_result.fetchall() - ) # pylint: disable=protected-access + return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access @staticmethod - def _save_params( - conn: Connection, table: Table, params: Dict[str, Any], **kwargs: Any - ) -> None: + def _save_params(conn: Connection, table: Table, + params: Dict[str, Any], **kwargs: Any) -> None: if not params: return - conn.execute( - table.insert(), - [ - {**kwargs, "param_id": key, "param_value": nullable(str, val)} - for (key, val) in params.items() - ], - ) + conn.execute(table.insert(), [ + { + **kwargs, + "param_id": key, + "param_value": nullable(str, val) + } + for (key, val) in params.items() + ]) def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]: timestamp = utcify_timestamp(timestamp, origin="local") _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp) if running: - pending_status = ["PENDING", "READY", "RUNNING"] + pending_status = ['PENDING', 'READY', 'RUNNING'] else: - pending_status = ["PENDING"] + pending_status = ['PENDING'] with self._engine.connect() as conn: - cur_trials = conn.execute( - self._schema.trial.select().where( - self._schema.trial.c.exp_id == self._experiment_id, - ( - self._schema.trial.c.ts_start.is_(None) - | (self._schema.trial.c.ts_start <= timestamp) - ), - self._schema.trial.c.ts_end.is_(None), - self._schema.trial.c.status.in_(pending_status), - ) - ) + cur_trials = conn.execute(self._schema.trial.select().where( + self._schema.trial.c.exp_id == self._experiment_id, + (self._schema.trial.c.ts_start.is_(None) | + (self._schema.trial.c.ts_start <= timestamp)), + self._schema.trial.c.ts_end.is_(None), + self._schema.trial.c.status.in_(pending_status), + )) for trial in cur_trials.fetchall(): tunables = self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id - ) + conn, self._schema.config_param, "param", + config_id=trial.config_id) config = self._get_key_val( - conn, - self._schema.trial_param, - "param", - exp_id=self._experiment_id, - trial_id=trial.trial_id, - ) + conn, self._schema.trial_param, "param", + exp_id=self._experiment_id, trial_id=trial.trial_id) yield Trial( engine=self._engine, schema=self._schema, @@ -284,55 +235,42 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: Get the config ID for the given tunables. If the config does not exist, create a new record for it. 
""" - config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest() - cur_config = conn.execute( - self._schema.config.select().where(self._schema.config.c.config_hash == config_hash) - ).fetchone() + config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest() + cur_config = conn.execute(self._schema.config.select().where( + self._schema.config.c.config_hash == config_hash + )).fetchone() if cur_config is not None: return int(cur_config.config_id) # mypy doesn't know it's always int # Config not found, create a new one: - config_id: int = conn.execute( - self._schema.config.insert().values(config_hash=config_hash) - ).inserted_primary_key[0] + config_id: int = conn.execute(self._schema.config.insert().values( + config_hash=config_hash)).inserted_primary_key[0] self._save_params( - conn, - self._schema.config_param, + conn, self._schema.config_param, {tunable.name: tunable.value for (tunable, _group) in tunables}, - config_id=config_id, - ) + config_id=config_id) return config_id - def new_trial( - self, - tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None, - ) -> Storage.Trial: + def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None) -> Storage.Trial: ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local") _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start) with self._engine.begin() as conn: try: config_id = self._get_config_id(conn, tunables) - conn.execute( - self._schema.trial.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - config_id=config_id, - ts_start=ts_start, - status="PENDING", - ) - ) + conn.execute(self._schema.trial.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + config_id=config_id, + ts_start=ts_start, + status='PENDING', + )) # Note: config here is the framework config, not the target # environment config (i.e., tunables). if config is not None: self._save_params( - conn, - self._schema.trial_param, - config, - exp_id=self._experiment_id, - trial_id=self._trial_id, - ) + conn, self._schema.trial_param, config, + exp_id=self._experiment_id, trial_id=self._trial_id) trial = Trial( engine=self._engine, diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index f299bcff68..eaa6e1041f 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -35,17 +35,14 @@ class ExperimentSqlData(ExperimentData): scripts and mlos_bench configuration files. 
""" - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str): super().__init__( experiment_id=experiment_id, description=description, @@ -60,11 +57,9 @@ def __init__( def objectives(self) -> Dict[str, Literal["min", "max"]]: with self._engine.connect() as conn: objectives_db_data = conn.execute( - self._schema.objectives.select() - .where( + self._schema.objectives.select().where( self._schema.objectives.c.exp_id == self._experiment_id, - ) - .order_by( + ).order_by( self._schema.objectives.c.weight.desc(), self._schema.objectives.c.optimization_target.asc(), ) @@ -85,17 +80,13 @@ def trials(self) -> Dict[int, TrialData]: def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: with self._engine.connect() as conn: tunable_config_trial_groups = conn.execute( - self._schema.trial.select() - .with_only_columns( + self._schema.trial.select().with_only_columns( self._schema.trial.c.config_id, - func.min(self._schema.trial.c.trial_id) - .cast(Integer) - .label("tunable_config_trial_group_id"), # pylint: disable=not-callable - ) - .where( + func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable + 'tunable_config_trial_group_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, - ) - .group_by( + ).group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -115,14 +106,11 @@ def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: def tunable_configs(self) -> Dict[int, TunableConfigData]: with self._engine.connect() as conn: tunable_configs = conn.execute( - self._schema.trial.select() - .with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label("config_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label('config_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, - ) - .group_by( + ).group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -151,28 +139,20 @@ def default_tunable_config_id(self) -> Optional[int]: """ with self._engine.connect() as conn: query_results = conn.execute( - self._schema.trial.select() - .with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label("config_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label('config_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial_param.select() - .with_only_columns( - func.min(self._schema.trial_param.c.trial_id) - .cast(Integer) - .label("first_trial_id_with_defaults"), # pylint: disable=not-callable - ) - .where( + self._schema.trial_param.select().with_only_columns( + func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable + "first_trial_id_with_defaults"), + ).where( self._schema.trial_param.c.exp_id == self._experiment_id, self._schema.trial_param.c.param_id == "is_defaults", - func.lower(self._schema.trial_param.c.param_value, type_=String).in_( - ["1", "true"] - ), - ) - .scalar_subquery() - ), + func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]), + ).scalar_subquery() + ) ) ) 
min_default_trial_row = query_results.fetchone() @@ -181,24 +161,17 @@ def default_tunable_config_id(self) -> Optional[int]: return min_default_trial_row._tuple()[0] # fallback logic - assume minimum trial_id for experiment query_results = conn.execute( - self._schema.trial.select() - .with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label("config_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label('config_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial.select() - .with_only_columns( - func.min(self._schema.trial.c.trial_id) - .cast(Integer) - .label("first_trial_id"), - ) - .where( + self._schema.trial.select().with_only_columns( + func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"), + ).where( self._schema.trial.c.exp_id == self._experiment_id, - ) - .scalar_subquery() - ), + ).scalar_subquery() + ) ) ) min_trial_row = query_results.fetchone() diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 65f0e35694..9a1eca2744 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -80,6 +80,7 @@ def __init__(self, engine: Engine): Column("root_env_config", String(1024), nullable=False), Column("git_repo", String(1024), nullable=False), Column("git_commit", String(40), nullable=False), + PrimaryKeyConstraint("exp_id"), ) @@ -94,25 +95,20 @@ def __init__(self, engine: Engine): # Will need to adjust the insert and return values to support this # eventually. Column("weight", Float, nullable=True), + PrimaryKeyConstraint("exp_id", "optimization_target"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ) # A workaround for SQLAlchemy issue with autoincrement in DuckDB: if engine.dialect.name == "duckdb": - seq_config_id = Sequence("seq_config_id") - col_config_id = Column( - "config_id", - Integer, - seq_config_id, - server_default=seq_config_id.next_value(), - nullable=False, - primary_key=True, - ) + seq_config_id = Sequence('seq_config_id') + col_config_id = Column("config_id", Integer, seq_config_id, + server_default=seq_config_id.next_value(), + nullable=False, primary_key=True) else: - col_config_id = Column( - "config_id", Integer, nullable=False, primary_key=True, autoincrement=True - ) + col_config_id = Column("config_id", Integer, nullable=False, + primary_key=True, autoincrement=True) self.config = Table( "config", @@ -131,6 +127,7 @@ def __init__(self, engine: Engine): Column("ts_end", DateTime), # Should match the text IDs of `mlos_bench.environments.Status` enum: Column("status", String(self._STATUS_LEN), nullable=False), + PrimaryKeyConstraint("exp_id", "trial_id"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), @@ -144,6 +141,7 @@ def __init__(self, engine: Engine): Column("config_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), + PrimaryKeyConstraint("config_id", "param_id"), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), ) @@ -157,10 +155,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), + PrimaryKeyConstraint("exp_id", "trial_id", 
"param_id"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) self.trial_status = Table( @@ -170,10 +168,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("status", String(self._STATUS_LEN), nullable=False), + UniqueConstraint("exp_id", "trial_id", "ts"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) self.trial_result = Table( @@ -183,10 +181,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), + PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) self.trial_telemetry = Table( @@ -197,15 +195,15 @@ def __init__(self, engine: Engine): Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), + UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"), - ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] - ), + ForeignKeyConstraint(["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id]), ) _LOG.debug("Schema: %s", self._meta) - def create(self) -> "DbSchema": + def create(self) -> 'DbSchema': """ Create the DB schema. """ diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index dec1385cf2..bde38575bd 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -27,9 +27,10 @@ class SqlStorage(Storage): An implementation of the Storage interface using SQLAlchemy backend. 
""" - def __init__( - self, config: dict, global_config: Optional[dict] = None, service: Optional[Service] = None - ): + def __init__(self, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None): super().__init__(config, global_config, service) lazy_schema_create = self._config.pop("lazy_schema_create", False) self._log_sql = self._config.pop("log_sql", False) @@ -46,7 +47,7 @@ def __init__( @property def _schema(self) -> DbSchema: """Lazily create schema upon first access.""" - if not hasattr(self, "_db_schema"): + if not hasattr(self, '_db_schema'): self._db_schema = DbSchema(self._engine).create() if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("DDL statements:\n%s", self._schema) @@ -55,16 +56,13 @@ def _schema(self) -> DbSchema: def __repr__(self) -> str: return self._repr - def experiment( - self, - *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal["min", "max"]], - ) -> Storage.Experiment: + def experiment(self, *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment: return Experiment( engine=self._engine, schema=self._schema, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 189cc68ebd..7ac7958845 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -27,18 +27,15 @@ class Trial(Storage.Trial): Store the results of a single run of the experiment in SQL database. """ - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - config_id: int, - opt_targets: Dict[str, Literal["min", "max"]], - config: Optional[Dict[str, Any]] = None, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + config_id: int, + opt_targets: Dict[str, Literal['min', 'max']], + config: Optional[Dict[str, Any]] = None): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -50,9 +47,9 @@ def __init__( self._engine = engine self._schema = schema - def update( - self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update(self, status: Status, timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") metrics = super().update(status, timestamp, metrics) @@ -62,16 +59,13 @@ def update( if status.is_completed(): # Final update of the status and ts_end: cur_status = conn.execute( - self._schema.trial.update() - .where( + self._schema.trial.update().where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] - ), - ) - .values( + ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), + ).values( status=status.name, ts_end=timestamp, ) @@ -79,37 +73,29 @@ def update( if cur_status.rowcount not in {1, -1}: _LOG.warning("Trial %s :: update failed: %s", self, status) raise RuntimeError( - f"Failed to update the status of the trial {self} to {status}." 
- + f" ({cur_status.rowcount} rows)" - ) + f"Failed to update the status of the trial {self} to {status}." + + f" ({cur_status.rowcount} rows)") if metrics: - conn.execute( - self._schema.trial_result.insert().values( - [ - { - "exp_id": self._experiment_id, - "trial_id": self._trial_id, - "metric_id": key, - "metric_value": nullable(str, val), - } - for (key, val) in metrics.items() - ] - ) - ) + conn.execute(self._schema.trial_result.insert().values([ + { + "exp_id": self._experiment_id, + "trial_id": self._trial_id, + "metric_id": key, + "metric_value": nullable(str, val), + } + for (key, val) in metrics.items() + ])) else: # Update of the status and ts_start when starting the trial: assert metrics is None, f"Unexpected metrics for status: {status}" cur_status = conn.execute( - self._schema.trial.update() - .where( + self._schema.trial.update().where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ["RUNNING", "SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] - ), - ) - .values( + ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), + ).values( status=status.name, ts_start=timestamp, ) @@ -122,9 +108,8 @@ def update( raise return metrics - def update_telemetry( - self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] - ) -> None: + def update_telemetry(self, status: Status, timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]]) -> None: super().update_telemetry(status, timestamp, metrics) # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") @@ -135,18 +120,16 @@ def update_telemetry( # See Also: comments in with self._engine.begin() as conn: self._update_status(conn, status, timestamp) - for metric_ts, key, val in metrics: + for (metric_ts, key, val) in metrics: with self._engine.begin() as conn: try: - conn.execute( - self._schema.trial_telemetry.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=metric_ts, - metric_id=key, - metric_value=nullable(str, val), - ) - ) + conn.execute(self._schema.trial_telemetry.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=metric_ts, + metric_id=key, + metric_value=nullable(str, val), + )) except IntegrityError as ex: _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex) @@ -158,15 +141,12 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime) # Make sure to convert the timestamp to UTC before storing it in the database. 
timestamp = utcify_timestamp(timestamp, origin="local") try: - conn.execute( - self._schema.trial_status.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=timestamp, - status=status.name, - ) - ) + conn.execute(self._schema.trial_status.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=timestamp, + status=status.name, + )) except IntegrityError as ex: - _LOG.warning( - "Status with that timestamp already exists: %s %s :: %s", self, timestamp, ex - ) + _LOG.warning("Status with that timestamp already exists: %s %s :: %s", + self, timestamp, ex) diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index c5138f91af..5a6f8a5ee8 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -29,18 +29,15 @@ class TrialSqlData(TrialData): An interface to access the trial data stored in the SQL DB. """ - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - trial_id: int, - config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + trial_id: int, + config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status): super().__init__( experiment_id=experiment_id, trial_id=trial_id, @@ -59,9 +56,8 @@ def tunable_config(self) -> TunableConfigData: Note: this corresponds to the Trial object's "tunables" property. """ - return TunableConfigSqlData( - engine=self._engine, schema=self._schema, tunable_config_id=self._tunable_config_id - ) + return TunableConfigSqlData(engine=self._engine, schema=self._schema, + tunable_config_id=self._tunable_config_id) @property def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": @@ -72,13 +68,9 @@ def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": from mlos_bench.storage.sql.tunable_config_trial_group_data import ( TunableConfigTrialGroupSqlData, ) - - return TunableConfigTrialGroupSqlData( - engine=self._engine, - schema=self._schema, - experiment_id=self._experiment_id, - tunable_config_id=self._tunable_config_id, - ) + return TunableConfigTrialGroupSqlData(engine=self._engine, schema=self._schema, + experiment_id=self._experiment_id, + tunable_config_id=self._tunable_config_id) @property def results_df(self) -> pandas.DataFrame: @@ -87,19 +79,16 @@ def results_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_results = conn.execute( - self._schema.trial_result.select() - .where( + self._schema.trial_result.select().where( self._schema.trial_result.c.exp_id == self._experiment_id, - self._schema.trial_result.c.trial_id == self._trial_id, - ) - .order_by( + self._schema.trial_result.c.trial_id == self._trial_id + ).order_by( self._schema.trial_result.c.metric_id, ) ) return pandas.DataFrame( [(row.metric_id, row.metric_value) for row in cur_results.fetchall()], - columns=["metric", "value"], - ) + columns=['metric', 'value']) @property def telemetry_df(self) -> pandas.DataFrame: @@ -108,12 +97,10 @@ def telemetry_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select() - .where( + self._schema.trial_telemetry.select().where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == self._trial_id, - ) - .order_by( + 
self._schema.trial_telemetry.c.trial_id == self._trial_id + ).order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) @@ -121,12 +108,8 @@ def telemetry_df(self) -> pandas.DataFrame: # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. return pandas.DataFrame( - [ - (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) - for row in cur_telemetry.fetchall() - ], - columns=["ts", "metric", "value"], - ) + [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()], + columns=['ts', 'metric', 'value']) @property def metadata_df(self) -> pandas.DataFrame: @@ -137,16 +120,13 @@ def metadata_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_params = conn.execute( - self._schema.trial_param.select() - .where( + self._schema.trial_param.select().where( self._schema.trial_param.c.exp_id == self._experiment_id, - self._schema.trial_param.c.trial_id == self._trial_id, - ) - .order_by( + self._schema.trial_param.c.trial_id == self._trial_id + ).order_by( self._schema.trial_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_params.fetchall()], - columns=["parameter", "value"], - ) + columns=['parameter', 'value']) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py index 2441f70b9c..e484979790 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py @@ -20,7 +20,10 @@ class TunableConfigSqlData(TunableConfigData): A configuration in this context is the set of tunable parameter values. """ - def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + tunable_config_id: int): super().__init__(tunable_config_id=tunable_config_id) self._engine = engine self._schema = schema @@ -29,13 +32,12 @@ def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int): def config_df(self) -> pandas.DataFrame: with self._engine.connect() as conn: cur_config = conn.execute( - self._schema.config_param.select() - .where(self._schema.config_param.c.config_id == self._tunable_config_id) - .order_by( + self._schema.config_param.select().where( + self._schema.config_param.c.config_id == self._tunable_config_id + ).order_by( self._schema.config_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_config.fetchall()], - columns=["parameter", "value"], - ) + columns=['parameter', 'value']) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index 3520e77c60..eb389a5940 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -33,15 +33,12 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData): (e.g., for repeats), which we call a (tunable) config trial group. 
""" - def __init__( - self, - *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None, - ): + def __init__(self, *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None): super().__init__( experiment_id=experiment_id, tunable_config_id=tunable_config_id, @@ -56,26 +53,20 @@ def _get_tunable_config_trial_group_id(self) -> int: """ with self._engine.connect() as conn: tunable_config_trial_group = conn.execute( - self._schema.trial.select() - .with_only_columns( - func.min(self._schema.trial.c.trial_id) - .cast(Integer) - .label("tunable_config_trial_group_id"), # pylint: disable=not-callable - ) - .where( + self._schema.trial.select().with_only_columns( + func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable + 'tunable_config_trial_group_id'), + ).where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.config_id == self._tunable_config_id, - ) - .group_by( + ).group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) ) row = tunable_config_trial_group.fetchone() assert row is not None - return row._tuple()[ - 0 - ] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy + return row._tuple()[0] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy @property def tunable_config(self) -> TunableConfigData: @@ -95,12 +86,8 @@ def trials(self) -> Dict[int, "TrialData"]: trials : Dict[int, TrialData] A dictionary of the trials' data, keyed by trial id. """ - return common.get_trials( - self._engine, self._schema, self._experiment_id, self._tunable_config_id - ) + return common.get_trials(self._engine, self._schema, self._experiment_id, self._tunable_config_id) @property def results_df(self) -> pandas.DataFrame: - return common.get_results_df( - self._engine, self._schema, self._experiment_id, self._tunable_config_id - ) + return common.get_results_df(self._engine, self._schema, self._experiment_id, self._tunable_config_id) diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index 22e629fc82..220f3d812c 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -13,9 +13,9 @@ from mlos_bench.storage.base_storage import Storage -def from_config( - config_file: str, global_configs: Optional[List[str]] = None, **kwargs: Any -) -> Storage: +def from_config(config_file: str, + global_configs: Optional[List[str]] = None, + **kwargs: Any) -> Storage: """ Create a new storage object from JSON5 config file. 
@@ -36,7 +36,7 @@ def from_config( config_path: List[str] = kwargs.get("config_path", []) config_loader = ConfigPersistenceService({"config_path": config_path}) global_config = {} - for fname in global_configs or []: + for fname in (global_configs or []): config = config_loader.load_config(fname, ConfigSchema.GLOBALS) global_config.update(config) config_path += config.get("config_path", []) diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py index d16dc81b79..a4610da8de 100644 --- a/mlos_bench/mlos_bench/storage/util.py +++ b/mlos_bench/mlos_bench/storage/util.py @@ -25,18 +25,16 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu A dataframe with exactly two columns, 'parameter' (or 'metric') and 'value', where 'parameter' is a string and 'value' is some TunableValue or None. """ - if dataframe.columns.tolist() == ["metric", "value"]: + if dataframe.columns.tolist() == ['metric', 'value']: dataframe = dataframe.copy() - dataframe.rename(columns={"metric": "parameter"}, inplace=True) - assert dataframe.columns.tolist() == ["parameter", "value"] + dataframe.rename(columns={'metric': 'parameter'}, inplace=True) + assert dataframe.columns.tolist() == ['parameter', 'value'] data = {} - for _, row in dataframe.astype("O").iterrows(): - if not isinstance(row["value"], TunableValueTypeTuple): + for _, row in dataframe.astype('O').iterrows(): + if not isinstance(row['value'], TunableValueTypeTuple): raise TypeError(f"Invalid column type: {type(row['value'])} value: {row['value']}") - assert isinstance(row["parameter"], str) - if row["parameter"] in data: + assert isinstance(row['parameter'], str) + if row['parameter'] in data: raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe") - data[row["parameter"]] = ( - try_parse_val(row["value"]) if isinstance(row["value"], str) else row["value"] - ) + data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value'] return data diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index 3b8c23a70c..26aa142441 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -29,34 +29,26 @@ None, ] ZONE_INFO: List[Optional[tzinfo]] = [ - nullable(pytz.timezone, zone_name) for zone_name in ZONE_NAMES + nullable(pytz.timezone, zone_name) + for zone_name in ZONE_NAMES ] # A decorator for tests that require docker. # Use with @requires_docker above a test_...() function. -DOCKER = shutil.which("docker") +DOCKER = shutil.which('docker') if DOCKER: - cmd = run( - "docker builder inspect default || docker buildx inspect default", - shell=True, - check=False, - capture_output=True, - ) + cmd = run("docker builder inspect default || docker buildx inspect default", shell=True, check=False, capture_output=True) stdout = cmd.stdout.decode() - if cmd.returncode != 0 or not any( - line for line in stdout.splitlines() if "Platform" in line and "linux" in line - ): + if cmd.returncode != 0 or not any(line for line in stdout.splitlines() if 'Platform' in line and 'linux' in line): debug("Docker is available but missing support for targeting linux platform.") DOCKER = None -requires_docker = pytest.mark.skipif( - not DOCKER, reason="Docker with Linux support is not available on this system." -) +requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.') # A decorator for tests that require ssh. 
# Use with @requires_ssh above a test_...() function. -SSH = shutil.which("ssh") -requires_ssh = pytest.mark.skipif(not SSH, reason="ssh is not available on this system.") +SSH = shutil.which('ssh') +requires_ssh = pytest.mark.skipif(not SSH, reason='ssh is not available on this system.') # A common seed to use to avoid tracking down race conditions and intermingling # issues of seeds across tests that run in non-deterministic parallel orders. @@ -139,14 +131,8 @@ def are_dir_trees_equal(dir1: str, dir2: str) -> bool: """ # See Also: https://stackoverflow.com/a/6681395 dirs_cmp = filecmp.dircmp(dir1, dir2) - if ( - len(dirs_cmp.left_only) > 0 - or len(dirs_cmp.right_only) > 0 - or len(dirs_cmp.funny_files) > 0 - ): - warning( - f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}" - ) + if len(dirs_cmp.left_only) > 0 or len(dirs_cmp.right_only) > 0 or len(dirs_cmp.funny_files) > 0: + warning(f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}") return False (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) if len(mismatch) > 0 or len(errors) > 0: diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index 61fb063a52..4d728b4037 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -21,11 +21,9 @@ BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace("\\", "/") -def locate_config_examples( - root_dir: str, - config_examples_dir: str, - examples_filter: Optional[Callable[[List[str]], List[str]]] = None, -) -> List[str]: +def locate_config_examples(root_dir: str, + config_examples_dir: str, + examples_filter: Optional[Callable[[List[str]], List[str]]] = None) -> List[str]: """Locates all config examples in the given directory. 
Parameters diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index 7c1d55ef9f..e1e26d7d8b 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -43,9 +43,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ - *locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs - ), + *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), ] assert configs @@ -53,9 +51,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.skip(reason="Use full Launcher test (below) instead now.") @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: # pragma: no cover +def test_load_cli_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: # pragma: no cover """Tests loading a config example.""" # pylint: disable=too-complex config = config_loader_service.load_config(config_path, ConfigSchema.CLI) @@ -65,7 +61,7 @@ def test_load_cli_config_examples( assert isinstance(config_paths, list) config_paths.reverse() for path in config_paths: - config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access + config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access # Foreach arg that references another file, see if we can at least load that too. args_to_skip = { @@ -102,9 +98,7 @@ def test_load_cli_config_examples( @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples_via_launcher( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example via the Launcher.""" config = config_loader_service.load_config(config_path, ConfigSchema.CLI) assert isinstance(config, dict) @@ -112,12 +106,10 @@ def test_load_cli_config_examples_via_launcher( # Try to load the CLI config by instantiating a launcher. # To do this we need to make sure to give it a few extra paths and globals # to look for for our examples. 
- cli_args = ( - f"--config {config_path}" - + f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" - + f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" - + f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" - ) + cli_args = f"--config {config_path}" + \ + f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + \ + f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + \ + f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" launcher = Launcher(description=__name__, long_text=config_path, argv=cli_args.split()) assert launcher @@ -128,16 +120,15 @@ def test_load_cli_config_examples_via_launcher( assert isinstance(config_paths, list) for path in config_paths: # Note: Checks that the order is maintained are handled in launcher_parse_args.py - assert any( - config_path.endswith(path) for config_path in launcher.config_loader.config_paths - ), f"Expected {path} to be in {launcher.config_loader.config_paths}" + assert any(config_path.endswith(path) for config_path in launcher.config_loader.config_paths), \ + f"Expected {path} to be in {launcher.config_loader.config_paths}" - if "experiment_id" in config: - assert launcher.global_config["experiment_id"] == config["experiment_id"] - if "trial_id" in config: - assert launcher.global_config["trial_id"] == config["trial_id"] + if 'experiment_id' in config: + assert launcher.global_config['experiment_id'] == config['experiment_id'] + if 'trial_id' in config: + assert launcher.global_config['trial_id'] == config['trial_id'] - expected_log_level = logging.getLevelName(config.get("log_level", "INFO")) + expected_log_level = logging.getLevelName(config.get('log_level', "INFO")) if isinstance(expected_log_level, int): expected_log_level = logging.getLevelName(expected_log_level) current_log_level = logging.getLevelName(logging.root.getEffectiveLevel()) @@ -145,7 +136,7 @@ def test_load_cli_config_examples_via_launcher( # TODO: Check that the log_file handler is set correctly. - expected_teardown = config.get("teardown", True) + expected_teardown = config.get('teardown', True) assert launcher.teardown == expected_teardown # Note: Testing of "globals" processing handled in launcher_parse_args_test.py @@ -154,30 +145,22 @@ def test_load_cli_config_examples_via_launcher( # Launcher loaded the expected types as well. 
assert isinstance(launcher.environment, Environment) - env_config = launcher.config_loader.load_config( - config["environment"], ConfigSchema.ENVIRONMENT - ) + env_config = launcher.config_loader.load_config(config["environment"], ConfigSchema.ENVIRONMENT) assert check_class_name(launcher.environment, env_config["class"]) assert isinstance(launcher.optimizer, Optimizer) if "optimizer" in config: - opt_config = launcher.config_loader.load_config( - config["optimizer"], ConfigSchema.OPTIMIZER - ) + opt_config = launcher.config_loader.load_config(config["optimizer"], ConfigSchema.OPTIMIZER) assert check_class_name(launcher.optimizer, opt_config["class"]) assert isinstance(launcher.storage, Storage) if "storage" in config: - storage_config = launcher.config_loader.load_config( - config["storage"], ConfigSchema.STORAGE - ) + storage_config = launcher.config_loader.load_config(config["storage"], ConfigSchema.STORAGE) assert check_class_name(launcher.storage, storage_config["class"]) assert isinstance(launcher.scheduler, Scheduler) if "scheduler" in config: - scheduler_config = launcher.config_loader.load_config( - config["scheduler"], ConfigSchema.SCHEDULER - ) + scheduler_config = launcher.config_loader.load_config(config["scheduler"], ConfigSchema.SCHEDULER) assert check_class_name(launcher.scheduler, scheduler_config["class"]) # TODO: Check that the launcher assigns the tunables values as expected. diff --git a/mlos_bench/mlos_bench/tests/config/conftest.py b/mlos_bench/mlos_bench/tests/config/conftest.py index 2c3932a128..fdcb3370cf 100644 --- a/mlos_bench/mlos_bench/tests/config/conftest.py +++ b/mlos_bench/mlos_bench/tests/config/conftest.py @@ -22,11 +22,9 @@ @pytest.fixture def config_loader_service() -> ConfigPersistenceService: """Config loader service fixture.""" - return ConfigPersistenceService( - config={ - "config_path": [ - str(files("mlos_bench.tests.config")), - path_join(str(files("mlos_bench.tests.config")), "globals"), - ] - } - ) + return ConfigPersistenceService(config={ + "config_path": [ + str(files("mlos_bench.tests.config")), + path_join(str(files("mlos_bench.tests.config")), "globals"), + ] + }) diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 2369b0c27a..42925a0a5d 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -27,24 +27,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" - configs_to_filter = [ - config_path - for config_path in configs_to_filter - if not config_path.endswith("-tunables.jsonc") - ] + configs_to_filter = [config_path for config_path in configs_to_filter if not config_path.endswith("-tunables.jsonc")] return configs_to_filter -configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_environment_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests 
loading an environment config example.""" envs = load_environment_config_examples(config_loader_service, config_path) for env in envs: @@ -52,15 +44,11 @@ def test_load_environment_config_examples( assert isinstance(env, Environment) -def load_environment_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> List[Environment]: +def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]: """Loads an environment config example.""" # Make sure that any "required_args" are provided. - global_config = config_loader_service.load_config( - "experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS - ) - global_config.setdefault("trial_id", 1) # normally populated by Launcher + global_config = config_loader_service.load_config("experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS) + global_config.setdefault('trial_id', 1) # normally populated by Launcher # Make sure we have the required services for the envs being used. mock_service_configs = [ @@ -72,34 +60,24 @@ def load_environment_config_examples( "services/remote/mock/mock_auth_service.jsonc", ] - tunable_groups = TunableGroups() # base tunable groups that all others get built on + tunable_groups = TunableGroups() # base tunable groups that all others get built on for mock_service_config_path in mock_service_configs: - mock_service_config = config_loader_service.load_config( - mock_service_config_path, ConfigSchema.SERVICE - ) - config_loader_service.register( - config_loader_service.build_service( - config=mock_service_config, parent=config_loader_service - ).export() - ) + mock_service_config = config_loader_service.load_config(mock_service_config_path, ConfigSchema.SERVICE) + config_loader_service.register(config_loader_service.build_service( + config=mock_service_config, parent=config_loader_service).export()) envs = config_loader_service.load_environment_list( - config_path, tunable_groups, global_config, service=config_loader_service - ) + config_path, tunable_groups, global_config, service=config_loader_service) return envs -composite_configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/" -) +composite_configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/") assert composite_configs @pytest.mark.parametrize("config_path", composite_configs) -def test_load_composite_env_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a composite env config example.""" envs = load_environment_config_examples(config_loader_service, config_path) assert len(envs) == 1 @@ -112,15 +90,11 @@ def test_load_composite_env_config_examples( assert child_env.tunable_params is not None checked_child_env_groups = set() - for child_tunable, child_group in child_env.tunable_params: + for (child_tunable, child_group) in child_env.tunable_params: # Lookup that tunable in the composite env. assert child_tunable in composite_env.tunable_params - (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable( - child_tunable - ) - assert ( - child_tunable is composite_tunable - ) # Check that the tunables are the same object. 
+ (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable) + assert child_tunable is composite_tunable # Check that the tunables are the same object. if child_group.name not in checked_child_env_groups: assert child_group is composite_group checked_child_env_groups.add(child_group.name) diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index fd53d63788..4d8c93fdff 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -29,9 +29,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ # *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), - *locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs - ), + *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, "experiments", filter_configs), ] @@ -39,9 +37,7 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) -def test_load_globals_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_globals_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.GLOBALS) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index c504a6d50f..6cb6253dea 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -30,16 +30,12 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_optimizer_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_optimizer_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index 6d2cabaa8a..e4264003e1 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -34,17 +34,14 @@ def __hash__(self) -> int: # The different type of schema test cases we expect to have. 
-_SCHEMA_TEST_TYPES = { - x.test_case_type: x - for x in ( - SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}), - SchemaTestType(test_case_type="bad", test_case_subtypes={"invalid", "unhandled"}), - ) -} +_SCHEMA_TEST_TYPES = {x.test_case_type: x for x in ( + SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}), + SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}), +)} @dataclass -class SchemaTestCaseInfo: +class SchemaTestCaseInfo(): """ Some basic info about a schema test case. """ @@ -64,17 +61,15 @@ def check_schema_dir_layout(test_cases_root: str) -> None: any extra configs or test cases. """ for test_case_dir in os.listdir(test_cases_root): - if test_case_dir == "README.md": + if test_case_dir == 'README.md': continue if test_case_dir not in _SCHEMA_TEST_TYPES: raise NotImplementedError(f"Unhandled test case type: {test_case_dir}") for test_case_subdir in os.listdir(os.path.join(test_cases_root, test_case_dir)): - if test_case_subdir == "README.md": + if test_case_subdir == 'README.md': continue if test_case_subdir not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes: - raise NotImplementedError( - f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}" - ) + raise NotImplementedError(f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}") @dataclass @@ -92,21 +87,15 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: """ Gets a dict of schema test cases from the given root. """ - test_cases = TestCases( - by_path={}, - by_type={x: {} for x in _SCHEMA_TEST_TYPES}, - by_subtype={ - y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes - }, - ) + test_cases = TestCases(by_path={}, + by_type={x: {} for x in _SCHEMA_TEST_TYPES}, + by_subtype={y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes}) check_schema_dir_layout(test_cases_root) # Note: we sort the test cases so that we can deterministically test them in parallel. 
- for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items(): + for (test_case_type, schema_test_type) in _SCHEMA_TEST_TYPES.items(): for test_case_subtype in schema_test_type.test_case_subtypes: - for test_case_file in locate_config_examples( - test_cases_root, os.path.join(test_case_type, test_case_subtype) - ): - with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh: + for test_case_file in locate_config_examples(test_cases_root, os.path.join(test_case_type, test_case_subtype)): + with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh: try: test_case_info = SchemaTestCaseInfo( config=json5.load(test_case_fh), @@ -115,12 +104,8 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: test_case_subtype=test_case_subtype, ) test_cases.by_path[test_case_info.test_case_file] = test_case_info - test_cases.by_type[test_case_info.test_case_type][ - test_case_info.test_case_file - ] = test_case_info - test_cases.by_subtype[test_case_info.test_case_subtype][ - test_case_info.test_case_file - ] = test_case_info + test_cases.by_type[test_case_info.test_case_type][test_case_info.test_case_file] = test_case_info + test_cases.by_subtype[test_case_info.test_case_subtype][test_case_info.test_case_file] = test_case_info except Exception as ex: raise RuntimeError("Failed to load test case: " + test_case_file) from ex assert test_cases @@ -132,9 +117,7 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: return test_cases -def check_test_case_against_schema( - test_case: SchemaTestCaseInfo, schema_type: ConfigSchema -) -> None: +def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: """ Checks the given test case against the given schema. @@ -159,9 +142,7 @@ def check_test_case_against_schema( raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}") -def check_test_case_config_with_extra_param( - test_case: SchemaTestCaseInfo, schema_type: ConfigSchema -) -> None: +def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: """ Checks that the config fails to validate if extra params are present in certain places. """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 32ea0b9713..5dd1666008 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -26,7 +26,6 @@ # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_cli_configs_against_schema(test_case_name: str) -> None: """ @@ -45,9 +44,7 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the cli config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. 
# The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index 1528d8d164..dc3cd40425 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -33,21 +33,17 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_ENV_CLASSES = { - ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. + ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. } -expected_environment_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Environment, pkg_name="mlos_bench") - if subclass not in NON_CONFIG_ENV_CLASSES -] +expected_environment_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass + in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench') + if subclass not in NON_CONFIG_ENV_CLASSES] assert expected_environment_class_names COMPOSITE_ENV_CLASS_NAME = CompositeEnv.__module__ + "." + CompositeEnv.__name__ -expected_leaf_environment_class_names = [ - subclass_name - for subclass_name in expected_environment_class_names - if subclass_name != COMPOSITE_ENV_CLASS_NAME -] +expected_leaf_environment_class_names = [subclass_name for subclass_name in expected_environment_class_names + if subclass_name != COMPOSITE_ENV_CLASS_NAME] # Do the full cross product of all the test cases and all the Environment types. @@ -61,13 +57,11 @@ def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_c if try_resolve_class_name(test_case.config.get("class")) == env_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}" - ) + f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_environment_configs_against_schema(test_case_name: str) -> None: """ @@ -82,9 +76,5 @@ def test_environment_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the environment config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 508787a84b..5045bf510b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -25,7 +25,6 @@ # Now we actually perform all of those validation tests. 
- @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_globals_configs_against_schema(test_case_name: str) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index ef5c0edfa3..e9ee653644 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -33,12 +33,9 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_optimizer_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses( - Optimizer, pkg_name="mlos_bench" # type: ignore[type-abstract] - ) -] +expected_mlos_bench_optimizer_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Optimizer, # type: ignore[type-abstract] + pkg_name='mlos_bench')] assert expected_mlos_bench_optimizer_class_names # Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None). @@ -53,9 +50,7 @@ # Do the full cross product of all the test cases and all the optimizer types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names) -def test_case_coverage_mlos_bench_optimizer_type( - test_case_subtype: str, mlos_bench_optimizer_type: str -) -> None: +def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_bench_optimizer_type: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench optimizer type. """ @@ -63,9 +58,7 @@ def test_case_coverage_mlos_bench_optimizer_type( if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}" - ) - + f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}") # Being a little lazy for the moment and relaxing the requirement that we have # a subtype test case for each optimizer and space adapter combo. @@ -74,58 +67,47 @@ def test_case_coverage_mlos_bench_optimizer_type( @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types) -def test_case_coverage_mlos_core_optimizer_type( - test_case_type: str, mlos_core_optimizer_type: Optional[OptimizerType] -) -> None: +def test_case_coverage_mlos_core_optimizer_type(test_case_type: str, + mlos_core_optimizer_type: Optional[OptimizerType]) -> None: """ Checks to see if there is a given type of test case for the given mlos_core optimizer type. 
""" optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): - if ( - try_resolve_class_name(test_case.config.get("class")) - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" - ): + if try_resolve_class_name(test_case.config.get("class")) \ + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": optimizer_type = None if test_case.config.get("config"): optimizer_type = test_case.config["config"].get("optimizer_type", None) if optimizer_type == optimizer_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}" - ) + f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}") @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types) -def test_case_coverage_mlos_core_space_adapter_type( - test_case_type: str, mlos_core_space_adapter_type: Optional[SpaceAdapterType] -) -> None: +def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str, + mlos_core_space_adapter_type: Optional[SpaceAdapterType]) -> None: """ Checks to see if there is a given type of test case for the given mlos_core space adapter type. """ - space_adapter_name = ( - None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name - ) + space_adapter_name = None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): - if ( - try_resolve_class_name(test_case.config.get("class")) - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" - ): + if try_resolve_class_name(test_case.config.get("class")) \ + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": space_adapter_type = None if test_case.config.get("config"): space_adapter_type = test_case.config["config"].get("space_adapter_type", None) if space_adapter_type == space_adapter_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}" - ) + f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_optimizer_configs_against_schema(test_case_name: str) -> None: """ @@ -140,9 +122,5 @@ def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the optimizer config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 23bd17b1e7..8fccba8bc7 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -30,12 +30,9 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_scheduler_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses( - Scheduler, pkg_name="mlos_bench" # type: ignore[type-abstract] - ) -] +expected_mlos_bench_scheduler_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Scheduler, # type: ignore[type-abstract] + pkg_name='mlos_bench')] assert expected_mlos_bench_scheduler_class_names # Do the full cross product of all the test cases and all the scheduler types. @@ -43,9 +40,7 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names) -def test_case_coverage_mlos_bench_scheduler_type( - test_case_subtype: str, mlos_bench_scheduler_type: str -) -> None: +def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_bench_scheduler_type: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench scheduler type. """ @@ -53,9 +48,7 @@ def test_case_coverage_mlos_bench_scheduler_type( if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_scheduler_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}" - ) - + f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}") # Now we actually perform all of those validation tests. @@ -74,12 +67,8 @@ def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the scheduler config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 032b4c0aad..64c6fccccd 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -38,17 +38,16 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_SERVICE_CLASSES = { - ConfigPersistenceService, # configured thru the launcher cli args - TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. - AzureDeploymentService, # ABCMeta abstract base class - SshService, # ABCMeta abstract base class + ConfigPersistenceService, # configured thru the launcher cli args + TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. + AzureDeploymentService, # ABCMeta abstract base class + SshService, # ABCMeta abstract base class } -expected_service_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Service, pkg_name="mlos_bench") - if subclass not in NON_CONFIG_SERVICE_CLASSES -] +expected_service_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass + in get_all_concrete_subclasses(Service, pkg_name='mlos_bench') + if subclass not in NON_CONFIG_SERVICE_CLASSES] assert expected_service_class_names @@ -62,7 +61,7 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): config_list: List[Dict[str, Any]] if not isinstance(test_case.config, dict): - continue # type: ignore[unreachable] + continue # type: ignore[unreachable] if "class" not in test_case.config: config_list = test_case.config["services"] else: @@ -71,13 +70,11 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c if try_resolve_class_name(config.get("class")) == service_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for service class {service_class}" - ) + f"Missing test case for subtype {test_case_subtype} for service class {service_class}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_service_configs_against_schema(test_case_name: str) -> None: """ @@ -92,9 +89,5 @@ def test_service_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the service config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index fd2de83cd0..9b362b5e0d 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -28,12 +28,9 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_storage_class_names = [ - subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses( - Storage, pkg_name="mlos_bench" # type: ignore[type-abstract] - ) -] +expected_mlos_bench_storage_class_names = [subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Storage, # type: ignore[type-abstract] + pkg_name='mlos_bench')] assert expected_mlos_bench_storage_class_names # Do the full cross product of all the test cases and all the storage types. @@ -41,9 +38,7 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names) -def test_case_coverage_mlos_bench_storage_type( - test_case_subtype: str, mlos_bench_storage_type: str -) -> None: +def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_bench_storage_type: str) -> None: """ Checks to see if there is a given type of test case for the given mlos_bench storage type. """ @@ -51,13 +46,11 @@ def test_case_coverage_mlos_bench_storage_type( if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_storage_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}" - ) + f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}") # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_storage_configs_against_schema(test_case_name: str) -> None: """ @@ -72,15 +65,9 @@ def test_storage_configs_with_extra_param(test_case_name: str) -> None: """ Checks that the storage config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE - ) - check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED - ) - - -if __name__ == "__main__": - pytest.main( - [__file__, "-n0"], - ) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE) + check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + + +if __name__ == '__main__': + pytest.main([__file__, '-n0'],) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py index 11849119c3..a6d0de9313 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py @@ -25,7 +25,6 @@ # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_params_configs_against_schema(test_case_name: str) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index 33124134e9..d871eaa212 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -25,7 +25,6 @@ # Now we actually perform all of those validation tests. - @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_values_configs_against_schema(test_case_name: str) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 8431251098..32034eb11c 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -25,27 +25,19 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" - def predicate(config_path: str) -> bool: - arm_template = config_path.find( - "services/remote/azure/arm-templates/" - ) >= 0 and config_path.endswith(".jsonc") + arm_template = config_path.find("services/remote/azure/arm-templates/") >= 0 and config_path.endswith(".jsonc") setup_rg_scripts = config_path.find("azure/scripts/setup-rg") >= 0 return not (arm_template or setup_rg_scripts) - return [config_path for config_path in configs_to_filter if predicate(config_path)] -configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_service_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_service_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) # 
Make an instance of the class based on the config. diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index d1d39ec4f5..2f9773a9b0 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -29,16 +29,12 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs -) +configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_storage_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str -) -> None: +def test_load_storage_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.STORAGE) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index 304d4903b3..58359eb983 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -42,7 +42,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score"], }, - tunables=tunable_groups, + tunables=tunable_groups ) @@ -59,7 +59,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score", "other_score"], }, - tunables=tunable_groups, + tunables=tunable_groups ) @@ -103,9 +103,7 @@ def docker_compose_project_name(short_testrun_uid: str) -> str: @pytest.fixture(scope="session") -def docker_services_lock( - shared_temp_dir: str, short_testrun_uid: str -) -> InterProcessReaderWriterLock: +def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessReaderWriterLock: """ Gets a pytest session lock for xdist workers to mark when they're using the docker services. @@ -115,9 +113,7 @@ def docker_services_lock( A lock to ensure that setup/teardown operations don't happen while a worker is using the docker services. """ - return InterProcessReaderWriterLock( - f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock" - ) + return InterProcessReaderWriterLock(f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock") @pytest.fixture(scope="session") @@ -130,9 +126,7 @@ def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> ------ A lock to ensure that only one worker is doing setup/teardown at a time. 
""" - return InterProcessLock( - f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock" - ) + return InterProcessLock(f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock") @pytest.fixture(scope="session") diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py index 8218577986..ac0b942167 100644 --- a/mlos_bench/mlos_bench/tests/environments/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/__init__.py @@ -16,13 +16,11 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def check_env_success( - env: Environment, - tunable_groups: TunableGroups, - expected_results: Dict[str, TunableValue], - expected_telemetry: List[Tuple[datetime, str, Any]], - global_config: Optional[dict] = None, -) -> None: +def check_env_success(env: Environment, + tunable_groups: TunableGroups, + expected_results: Dict[str, TunableValue], + expected_telemetry: List[Tuple[datetime, str, Any]], + global_config: Optional[dict] = None) -> None: """ Set up an environment and run a test experiment there. @@ -52,7 +50,7 @@ def check_env_success( assert telemetry == pytest.approx(expected_telemetry, nan_ok=True) env_context.teardown() - assert not env_context._is_ready # pylint: disable=protected-access + assert not env_context._is_ready # pylint: disable=protected-access def check_env_fail_telemetry(env: Environment, tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index 7be966d482..8afb8e5cda 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -28,13 +28,9 @@ def test_expand_groups() -> None: """ Check the dollar variable expansion for tunable groups. """ - assert Environment._expand_groups(["begin", "$list", "$empty", "$str", "end"], _GROUPS) == [ - "begin", - "c", - "d", - "efg", - "end", - ] + assert Environment._expand_groups( + ["begin", "$list", "$empty", "$str", "end"], + _GROUPS) == ["begin", "c", "d", "efg", "end"] def test_expand_groups_empty_input() -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index f7e0e86795..6497eb6985 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -40,20 +40,20 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "name": "Env 3 :: tmp_other_3", "class": "mlos_bench.environments.mock_env.MockEnv", "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"], - }, + } ] }, tunables=tunable_groups, service=LocalExecService( - config={"temp_dir": "_test_tmp_global"}, - parent=ConfigPersistenceService( - { - "config_path": [ - path_join(os.path.dirname(__file__), "../config", abs_path=True), - ] - } - ), - ), + config={ + "temp_dir": "_test_tmp_global" + }, + parent=ConfigPersistenceService({ + "config_path": [ + path_join(os.path.dirname(__file__), "../config", abs_path=True), + ] + }) + ) ) @@ -61,7 +61,7 @@ def test_composite_services(composite_env: CompositeEnv) -> None: """ Check that each environment gets its own instance of the services. 
""" - for i, path in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): + for (i, path) in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): service = composite_env.children[i]._service # pylint: disable=protected-access assert service is not None and hasattr(service, "temp_dir_context") with service.temp_dir_context() as temp_dir: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index 1a159ef4ef..742eaf3c79 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -28,7 +28,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", "someConst": "root", - "global_param": "default", + "global_param": "default" }, "children": [ { @@ -43,7 +43,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "someConst", "global_param"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, { "name": "Mock Server Environment 2", @@ -53,12 +53,12 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vmName": "$vm_server_name", "EnvId": 2, - "global_param": "local", + "global_param": "local" }, "required_args": ["vmName"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, { "name": "Mock Control Environment 3", @@ -72,13 +72,15 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "vm_server_name", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, - }, - ], + } + } + ] }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={"global_param": "global_value"}, + global_config={ + "global_param": "global_value" + } ) @@ -88,26 +90,26 @@ def test_composite_env_params(composite_env: CompositeEnv) -> None: NOTE: The current logic is that variables flow down via required_args and const_args, parent """ assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value", # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the 
child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", - "vm_server_name": "Mock Server VM", + "vm_server_name": "Mock Server VM" # "global_param": "global_value" # not required, so not picked from the global_config } @@ -116,35 +118,33 @@ def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: Tunabl """ Check that the child environments update their tunable parameters. """ - tunable_groups.assign( - { - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - } - ) + tunable_groups.assign({ + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + }) with composite_env as env_context: assert env_context.setup(tunable_groups) assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value", # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "mwait", # tunable_params from the parent "vm_client_name": "Mock Client VM", "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config @@ -163,7 +163,7 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", - "someConst": "root", + "someConst": "root" }, "children": [ { @@ -191,11 +191,11 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "EnvId", "someConst", "vm_server_name", - "global_param", + "global_param" ], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, # ... ], @@ -220,17 +220,20 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "EnvId", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - }, + } }, # ... 
], }, }, - ], + + ] }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={"global_param": "global_value"}, + global_config={ + "global_param": "global_value" + } ) @@ -241,56 +244,52 @@ def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None """ assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", # "global_param": "global_value" # not required, so not picked from the global_config } -def test_nested_composite_env_setup( - nested_composite_env: CompositeEnv, tunable_groups: TunableGroups -) -> None: +def test_nested_composite_env_setup(nested_composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: """ Check that the child environments update their tunable parameters. 
""" - tunable_groups.assign( - { - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - } - ) + tunable_groups.assign({ + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + }) with nested_composite_env as env_context: assert env_context.setup(tunable_groups) assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value", # pulled in from the global_config + "global_param": "global_value" # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", } diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py index cbfd6d75ed..7395aa3e15 100644 --- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py +++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py @@ -16,7 +16,9 @@ def test_one_group(tunable_groups: TunableGroups) -> None: Make sure only one tunable group is available to the environment. """ env = MockEnv( - name="Test Env", config={"tunable_params": ["provision"]}, tunables=tunable_groups + name="Test Env", + config={"tunable_params": ["provision"]}, + tunables=tunable_groups ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -30,7 +32,7 @@ def test_two_groups(tunable_groups: TunableGroups) -> None: env = MockEnv( name="Test Env", config={"tunable_params": ["provision", "kernel"]}, - tunables=tunable_groups, + tunables=tunable_groups ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -53,7 +55,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups, + tunables=tunable_groups ) expected_params = { "vmSize": "Standard_B4ms", @@ -78,7 +80,11 @@ def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None: """ Make sure that no tunable groups are available to the environment by default. """ - env = MockEnv(name="Test Env", config={}, tunables=tunable_groups) + env = MockEnv( + name="Test Env", + config={}, + tunables=tunable_groups + ) assert env.tunable_params.get_param_values() == {} @@ -87,7 +93,11 @@ def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None: Make sure that no tunable groups are available to the environment when explicitly specifying an empty list of tunable_params. 
""" - env = MockEnv(name="Test Env", config={"tunable_params": []}, tunables=tunable_groups) + env = MockEnv( + name="Test Env", + config={"tunable_params": []}, + tunables=tunable_groups + ) assert env.tunable_params.get_param_values() == {} @@ -104,7 +114,7 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups, + tunables=tunable_groups ) assert env.tunable_params.get_param_values() == {} @@ -127,7 +137,9 @@ def test_loader_level_include() -> None: env_json = { "class": "mlos_bench.environments.mock_env.MockEnv", "name": "Test Env", - "include_tunables": ["environments/os/linux/boot/linux-boot-tunables.jsonc"], + "include_tunables": [ + "environments/os/linux/boot/linux-boot-tunables.jsonc" + ], "config": { "tunable_params": ["linux-kernel-boot"], "const_args": { @@ -136,14 +148,12 @@ def test_loader_level_include() -> None: }, }, } - loader = ConfigPersistenceService( - { - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - } - ) + loader = ConfigPersistenceService({ + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + }) env = loader.build_environment(config=env_json, tunables=TunableGroups()) expected_params = { "align_va_addr": "on", diff --git a/mlos_bench/mlos_bench/tests/environments/local/__init__.py b/mlos_bench/mlos_bench/tests/environments/local/__init__.py index c68d2fa7b8..5d8fc32c6b 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/local/__init__.py @@ -32,20 +32,14 @@ def create_local_env(tunable_groups: TunableGroups, config: Dict[str, Any]) -> L env : LocalEnv A new instance of the local environment. """ - return LocalEnv( - name="TestLocalEnv", - config=config, - tunables=tunable_groups, - service=LocalExecService(parent=ConfigPersistenceService()), - ) + return LocalEnv(name="TestLocalEnv", config=config, tunables=tunable_groups, + service=LocalExecService(parent=ConfigPersistenceService())) -def create_composite_local_env( - tunable_groups: TunableGroups, - global_config: Dict[str, Any], - params: Dict[str, Any], - local_configs: List[Dict[str, Any]], -) -> CompositeEnv: +def create_composite_local_env(tunable_groups: TunableGroups, + global_config: Dict[str, Any], + params: Dict[str, Any], + local_configs: List[Dict[str, Any]]) -> CompositeEnv: """ Create a CompositeEnv with several LocalEnv instances. 
@@ -76,7 +70,7 @@ def create_composite_local_env( "config": config, } for (i, config) in enumerate(local_configs) - ], + ] }, tunables=tunable_groups, global_config=global_config, diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index 83dcc3ce5d..9bcb7aa218 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -43,7 +43,7 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - (var_prefix, var_suffix) = ("%", "%") if sys.platform == "win32" else ("$", "") + (var_prefix, var_suffix) = ("%", "%") if sys.platform == 'win32' else ("$", "") env = create_composite_local_env( tunable_groups=tunable_groups, @@ -67,8 +67,8 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo "required_args": ["errors", "reads"], "shell_env_params": [ "latency", # const_args overridden by the composite env - "errors", # Comes from the parent const_args - "reads", # const_args overridden by the global config + "errors", # Comes from the parent const_args + "reads" # const_args overridden by the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -90,9 +90,9 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo }, "required_args": ["writes"], "shell_env_params": [ - "throughput", # const_args overridden by the composite env - "score", # Comes from the local const_args - "writes", # Comes straight from the global config + "throughput", # const_args overridden by the composite env + "score", # Comes from the local const_args + "writes" # Comes straight from the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -106,13 +106,12 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo ], "read_results_file": "output.csv", "read_telemetry_file": "telemetry.csv", - }, - ], + } + ] ) check_env_success( - env, - tunable_groups, + env, tunable_groups, expected_results={ "latency": 4.2, "throughput": 768.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py index bdcd9f885f..20854b9f9e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py @@ -17,23 +17,19 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. 
""" - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", - ], - "results_stdout_pattern": r"(\w+),([0-9.]+)", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", + ], + "results_stdout_pattern": r"(\w+),([0-9.]+)", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -49,23 +45,19 @@ def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None: """ Print benchmark results to stdout and capture them in the LocalEnv. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern - ], - "results_stdout_pattern": r"^(\w+),([0-9.]+)$", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern + ], + "results_stdout_pattern": r"^(\w+),([0-9.]+)$", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -80,28 +72,24 @@ def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None: """ Print benchmark results to *BOTH* stdout and a file and extract the results from both. 
""" - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'stdout-msg,string'", - "echo '-------------------'", # Should be ignored - "echo 'metric,value' > output.csv", - "echo 'extra1,333' >> output.csv", - "echo 'extra2,444' >> output.csv", - "echo 'file-msg,string' >> output.csv", - ], - "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'stdout-msg,string'", + "echo '-------------------'", # Should be ignored + "echo 'metric,value' > output.csv", + "echo 'extra1,333' >> output.csv", + "echo 'extra2,444' >> output.csv", + "echo 'file-msg,string' >> output.csv", + ], + "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", + "read_results_file": "output.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index 2491e89e24..35bdb39486 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -37,29 +37,25 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,4.1' >> output.csv", - "echo 'throughput,512' >> output.csv", - "echo 'score,0.95' >> output.csv", - "echo '-------------------'", # This output does not go anywhere - "echo 'timestamp,metric,value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_results_file": "output.csv", - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,4.1' >> output.csv", + "echo 'throughput,512' >> output.csv", + "echo 'score,0.95' >> output.csv", + "echo '-------------------'", # This output does not go anywhere + "echo 'timestamp,metric,value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_results_file": "output.csv", + "read_telemetry_file": "telemetry.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 4.1, "throughput": 512.0, @@ -76,9 +72,7 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_no_header( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: """ Read the 
telemetry data with no header. """ @@ -90,22 +84,18 @@ def test_local_env_telemetry_no_header( time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - f"echo {time_str1},cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + f"echo {time_str1},cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={}, expected_telemetry=[ (ts1.astimezone(UTC), "cpu_load", 0.65), @@ -116,13 +106,9 @@ def test_local_env_telemetry_no_header( ) -@pytest.mark.filterwarnings( - "ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0" -) # pylint: disable=line-too-long # noqa +@pytest.mark.filterwarnings("ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0") # pylint: disable=line-too-long # noqa @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_wrong_header( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: """ Read the telemetry data with incorrect header. 
""" @@ -134,20 +120,17 @@ def test_local_env_telemetry_wrong_header( time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - # Error: the data is correct, but the header has unexpected column names - "echo 'ts,metric_name,metric_value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # Error: the data is correct, but the header has unexpected column names + "echo 'ts,metric_name,metric_value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_fail_telemetry(local_env, tunable_groups) @@ -165,19 +148,16 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env( - tunable_groups, - { - "run": [ - # Error: too many columns - f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # Error: too many columns + f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_fail_telemetry(local_env, tunable_groups) @@ -186,18 +166,15 @@ def test_local_env_telemetry_invalid_ts(tunable_groups: TunableGroups) -> None: """ Fail when the telemetry data has wrong format. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - # Error: field 1 must be a timestamp - "echo 1,cpu_load,0.65 > telemetry.csv", - "echo 2,mem_usage,10240 >> telemetry.csv", - "echo 3,cpu_load,0.8 >> telemetry.csv", - "echo 4,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # Error: field 1 must be a timestamp + "echo 1,cpu_load,0.65 > telemetry.csv", + "echo 2,mem_usage,10240 >> telemetry.csv", + "echo 3,cpu_load,0.8 >> telemetry.csv", + "echo 4,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }) check_env_fail_telemetry(local_env, tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 2b51ae1f0e..6cb4fd4f7e 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -16,22 +16,18 @@ def test_local_env(tunable_groups: TunableGroups) -> None: """ Produce benchmark and telemetry data in a local script and read it. 
""" - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,10' >> output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'metric,value' > output.csv", + "echo 'latency,10' >> output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 10.0, "throughput": 66.0, @@ -45,7 +41,9 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: """ Basic check that context support for Service mixins are handled when environment contexts are entered. """ - local_env = create_local_env(tunable_groups, {"run": ["echo NA"]}) + local_env = create_local_env(tunable_groups, { + "run": ["echo NA"] + }) # pylint: disable=protected-access assert local_env._service assert not local_env._service._in_context @@ -53,10 +51,10 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: with local_env as env_context: assert env_context._in_context assert local_env._service._in_context - assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) + assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) assert all(svc._in_context for svc in local_env._service._service_contexts) assert all(svc._in_context for svc in local_env._service._services) - assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) + assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) assert not local_env._service._service_contexts assert not any(svc._in_context for svc in local_env._service._services) @@ -65,18 +63,15 @@ def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: """ Fail if the results are not in the expected format. """ - local_env = create_local_env( - tunable_groups, - { - "run": [ - # No header - "echo 'latency,10' > output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + # No header + "echo 'latency,10' > output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }) with local_env as env_context: assert env_context.setup(tunable_groups) @@ -88,20 +83,16 @@ def test_local_env_wide(tunable_groups: TunableGroups) -> None: """ Produce benchmark data in wide format and read it. 
""" - local_env = create_local_env( - tunable_groups, - { - "run": [ - "echo 'latency,throughput,score' > output.csv", - "echo '10,66,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }, - ) + local_env = create_local_env(tunable_groups, { + "run": [ + "echo 'latency,throughput,score' > output.csv", + "echo '10,66,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }) check_env_success( - local_env, - tunable_groups, + local_env, tunable_groups, expected_results={ "latency": 10, "throughput": 66, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py index c6ece538f1..c16eac4459 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py @@ -18,30 +18,27 @@ def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: d """ Check that LocalEnv can set shell environment variables. """ - local_env = create_local_env( - tunable_groups, - { - "const_args": { - "const_arg": 111, # Passed into "shell_env_params" - "other_arg": 222, # NOT passed into "shell_env_params" - }, - "tunable_params": ["kernel"], - "shell_env_params": [ - "const_arg", # From "const_arg" - "kernel_sched_latency_ns", # From "tunable_params" - ], - "run": [ - "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", - f"echo {shell_subcmd} >> output.csv", - ], - "read_results_file": "output.csv", + local_env = create_local_env(tunable_groups, { + "const_args": { + "const_arg": 111, # Passed into "shell_env_params" + "other_arg": 222, # NOT passed into "shell_env_params" }, - ) + "tunable_params": ["kernel"], + "shell_env_params": [ + "const_arg", # From "const_arg" + "kernel_sched_latency_ns", # From "tunable_params" + ], + "run": [ + "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", + f"echo {shell_subcmd} >> output.csv", + ], + "read_results_file": "output.csv", + }) check_env_success(local_env, tunable_groups, expected, []) -@pytest.mark.skipif(sys.platform == "win32", reason="sh-like shell only") +@pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only") def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables in sh-like shell. @@ -50,15 +47,15 @@ def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd="$const_arg,$other_arg,$unknown_arg,$kernel_sched_latency_ns", expected={ - "const_arg": 111, # From "const_args" - "other_arg": float("NaN"), # Not included in "shell_env_params" - "unknown_arg": float("NaN"), # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - }, + "const_arg": 111, # From "const_args" + "other_arg": float("NaN"), # Not included in "shell_env_params" + "unknown_arg": float("NaN"), # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + } ) -@pytest.mark.skipif(sys.platform != "win32", reason="Windows only") +@pytest.mark.skipif(sys.platform != 'win32', reason="Windows only") def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: """ Check that LocalEnv can set shell environment variables on Windows / cmd shell. 
@@ -67,9 +64,9 @@ def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd=r"%const_arg%,%other_arg%,%unknown_arg%,%kernel_sched_latency_ns%", expected={ - "const_arg": 111, # From "const_args" - "other_arg": r"%other_arg%", # Not included in "shell_env_params" - "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - }, + "const_arg": 111, # From "const_args" + "other_arg": r"%other_arg%", # Not included in "shell_env_params" + "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + } ) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index 25e75cf748..8bce053f7b 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -25,14 +25,13 @@ def mock_fileshare_service() -> MockFileShareService: """ return MockFileShareService( config={"fileShareName": "MOCK_FILESHARE"}, - parent=LocalExecService(parent=ConfigPersistenceService()), + parent=LocalExecService(parent=ConfigPersistenceService()) ) @pytest.fixture -def local_fileshare_env( - tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService -) -> LocalFileShareEnv: +def local_fileshare_env(tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService) -> LocalFileShareEnv: """ Create a LocalFileShareEnv instance. """ @@ -41,12 +40,12 @@ def local_fileshare_env( config={ "const_args": { "experiment_id": "EXP_ID", # Passed into "shell_env_params" - "trial_id": 222, # NOT passed into "shell_env_params" + "trial_id": 222, # NOT passed into "shell_env_params" }, "tunable_params": ["boot"], "shell_env_params": [ - "trial_id", # From "const_arg" - "idle", # From "tunable_params", == "halt" + "trial_id", # From "const_arg" + "idle", # From "tunable_params", == "halt" ], "upload": [ { @@ -58,7 +57,9 @@ def local_fileshare_env( "to": "$experiment_id/$trial_id/input/data_$idle.csv", }, ], - "run": ["echo No-op run"], + "run": [ + "echo No-op run" + ], "download": [ { "from": "$experiment_id/$trial_id/$idle/data.csv", @@ -72,11 +73,9 @@ def local_fileshare_env( return env -def test_local_fileshare_env( - tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService, - local_fileshare_env: LocalFileShareEnv, -) -> None: +def test_local_fileshare_env(tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService, + local_fileshare_env: LocalFileShareEnv) -> None: """ Test that the LocalFileShareEnv correctly expands the `$VAR` variables in the upload and download sections of the config. 
diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py index c536c97a89..608edbf9ef 100644 --- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py @@ -42,22 +42,20 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr assert data["score"] == pytest.approx(75.0, 0.01) -@pytest.mark.parametrize( - ("tunable_values", "expected_score"), - [ - ( - {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000}, - 66.4, - ), - ( - {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000}, - 74.06, - ), - ], -) -def test_mock_env_assign( - mock_env: MockEnv, tunable_groups: TunableGroups, tunable_values: dict, expected_score: float -) -> None: +@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ + ({ + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 250000 + }, 66.4), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000 + }, 74.06), +]) +def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, + tunable_values: dict, expected_score: float) -> None: """ Check the benchmark values of the mock environment after the assignment. """ @@ -70,25 +68,21 @@ def test_mock_env_assign( assert data["score"] == pytest.approx(expected_score, 0.01) -@pytest.mark.parametrize( - ("tunable_values", "expected_score"), - [ - ( - {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000}, - 67.5, - ), - ( - {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000}, - 75.1, - ), - ], -) -def test_mock_env_no_noise_assign( - mock_env_no_noise: MockEnv, - tunable_groups: TunableGroups, - tunable_values: dict, - expected_score: float, -) -> None: +@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ + ({ + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 250000 + }, 67.5), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000 + }, 75.1), +]) +def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, + tunable_groups: TunableGroups, + tunable_values: dict, expected_score: float) -> None: """ Check the benchmark values of the noiseless mock environment after the assignment. 
""" diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index 6d47d1fc61..878531d799 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -38,31 +38,25 @@ def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None: "ssh_priv_key_path": ssh_test_server.id_rsa_path, } - service = ConfigPersistenceService( - config={"config_path": [str(files("mlos_bench.tests.config"))]} - ) + service = ConfigPersistenceService(config={"config_path": [str(files("mlos_bench.tests.config"))]}) config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc") - env = service.load_environment( - config_path, TunableGroups(), global_config=global_config, service=service - ) + env = service.load_environment(config_path, TunableGroups(), global_config=global_config, service=service) check_env_success( - env, - env.tunable_params, + env, env.tunable_params, expected_results={ "hostname": ssh_test_server.service_name, "username": ssh_test_server.username, "score": 0.9, - "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" + "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" "test_param": "unset", "FOO": "unset", "ssh_username": "unset", }, expected_telemetry=[], ) - assert not os.path.exists( - os.path.join(os.getcwd(), "output-downloaded.csv") - ), "output-downloaded.csv should have been cleaned up by temp_dir context" + assert not os.path.exists(os.path.join(os.getcwd(), "output-downloaded.csv")), \ + "output-downloaded.csv should have been cleaned up by temp_dir context" if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index b95666824a..377bc940a0 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -40,21 +40,16 @@ def __enter__(self) -> None: self.EVENT_LOOP_CONTEXT.enter() self._in_context = True - def __exit__( - self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType], - ) -> Literal[False]: + def __exit__(self, ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType]) -> Literal[False]: assert self._in_context self.EVENT_LOOP_CONTEXT.exit() self._in_context = False return False -@pytest.mark.filterwarnings( - "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" -) +@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") def test_event_loop_context() -> None: """Test event loop context background thread setup/cleanup handling.""" # pylint: disable=protected-access,too-many-statements @@ -92,16 +87,12 @@ def test_event_loop_context() -> None: assert event_loop_caller_instance_1._in_context assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 # We should only get one thread for all instances. 
- assert ( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread \ + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread \ is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop_thread - ) - assert ( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop \ + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop \ is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop - ) assert not event_loop_caller_instance_2._in_context @@ -113,38 +104,30 @@ def test_event_loop_context() -> None: assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( - asyncio.sleep(0.1, result="foo") - ) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == "foo" + assert future.result(timeout=0.2) == 'foo' assert 0.1 <= time.time() - start <= 0.2 # Once we exit the last context, the background thread should be stopped # and unusable for running co-routines. - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) + assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 0 assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None assert not EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() # Check that the event loop has no more tasks. - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_ready") + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_ready') # Windows ProactorEventLoopPolicy adds a dummy task. - if sys.platform == "win32" and isinstance( - EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop - ): + if sys.platform == 'win32' and isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop): assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 1 else: assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 0 - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_scheduled") + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_scheduled') assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) == 0 - with pytest.raises( - AssertionError - ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( - asyncio.sleep(0.1, result="foo") - ) + with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) raise ValueError(f"Future should not have been available to wait on {future.result()}") # Test that when re-entering the context we have the same event loop. @@ -155,14 +138,12 @@ def test_event_loop_context() -> None: # Test running again. 
start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( - asyncio.sleep(0.1, result="foo") - ) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == "foo" + assert future.result(timeout=0.2) == 'foo' assert 0.1 <= time.time() - start <= 0.2 -if __name__ == "__main__": +if __name__ == '__main__': # For debugging in Windows which has issues with pytest detection in vscode. pytest.main(["-n1", "--dist=no", "-k", "test_event_loop_context"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py index 25abf659ce..90aa7e08f7 100644 --- a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py @@ -14,33 +14,19 @@ @pytest.mark.parametrize( - ("argv", "expected_score"), - [ - ( - [ - "--config", - "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", - "--trial_config_repeat_count", - "5", - "--mock_env_seed", - "-1", # Deterministic Mock Environment. - ], - 67.40329, - ), - ( - [ - "--config", - "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", - "--trial_config_repeat_count", - "3", - "--max_suggestions", - "3", - "--mock_env_seed", - "42", # Noisy Mock Environment. - ], - 64.53897, - ), - ], + ("argv", "expected_score"), [ + ([ + "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", + "--trial_config_repeat_count", "5", + "--mock_env_seed", "-1", # Deterministic Mock Environment. + ], 67.40329), + ([ + "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", + "--trial_config_repeat_count", "3", + "--max_suggestions", "3", + "--mock_env_seed", "42", # Noisy Mock Environment. + ], 64.53897), + ] ) def test_main_bench(argv: List[str], expected_score: float) -> None: """ diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index b03c5a2733..634050d099 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -48,8 +48,8 @@ def config_paths() -> List[str]: """ return [ path_join(os.getcwd(), abs_path=True), - str(files("mlos_bench.config")), - str(files("mlos_bench.tests.config")), + str(files('mlos_bench.config')), + str(files('mlos_bench.tests.config')), ] @@ -64,23 +64,20 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == "win32": + if sys.platform == 'win32': # Some env tweaks for platform compatibility. - environ["USER"] = environ["USERNAME"] + environ['USER'] = environ['USERNAME'] # This is part of the minimal required args by the Launcher. 
- env_conf_path = "environments/mock/mock_env.jsonc" - cli_args = ( - "--config-paths " - + " ".join(config_paths) - + " --service services/remote/mock/mock_auth_service.jsonc" - + " --service services/remote/mock/mock_remote_exec_service.jsonc" - + " --scheduler schedulers/sync_scheduler.jsonc" - + f" --environment {env_conf_path}" - + " --globals globals/global_test_config.jsonc" - + " --globals globals/global_test_extra_config.jsonc" - " --test_global_value_2 from-args" - ) + env_conf_path = 'environments/mock/mock_env.jsonc' + cli_args = '--config-paths ' + ' '.join(config_paths) + \ + ' --service services/remote/mock/mock_auth_service.jsonc' + \ + ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ + ' --scheduler schedulers/sync_scheduler.jsonc' + \ + f' --environment {env_conf_path}' + \ + ' --globals globals/global_test_config.jsonc' + \ + ' --globals globals/global_test_extra_config.jsonc' \ + ' --test_global_value_2 from-args' launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -88,28 +85,27 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsLocalExec) assert isinstance(launcher.service, SupportsRemoteExec) # Check that the first --globals file is loaded and $var expansion is handled. - assert launcher.global_config["experiment_id"] == "MockExperiment" - assert launcher.global_config["testVmName"] == "MockExperiment-vm" + assert launcher.global_config['experiment_id'] == 'MockExperiment' + assert launcher.global_config['testVmName'] == 'MockExperiment-vm' # Check that secondary expansion also works. - assert launcher.global_config["testVnetName"] == "MockExperiment-vm-vnet" + assert launcher.global_config['testVnetName'] == 'MockExperiment-vm-vnet' # Check that the second --globals file is loaded. - assert launcher.global_config["test_global_value"] == "from-file" + assert launcher.global_config['test_global_value'] == 'from-file' # Check overriding values in a file from the command line. - assert launcher.global_config["test_global_value_2"] == "from-args" + assert launcher.global_config['test_global_value_2'] == 'from-args' # Check that we can expand a $var in a config file that references an environment variable. - assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join( - os.getcwd(), "foo", abs_path=True - ) - assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ + == path_join(os.getcwd(), "foo", abs_path=True) + assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' assert launcher.teardown # Check that the environment that got loaded looks to be of the right type. env_config = launcher.config_loader.load_config(env_conf_path, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config["class"]) + assert check_class_name(launcher.environment, env_config['class']) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, OneShotOptimizer) # Check that the optimizer got initialized with defaults. 
assert launcher.optimizer.tunable_params.is_defaults() - assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer + assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer # Check that we pick up the right scheduler config: assert isinstance(launcher.scheduler, SyncScheduler) assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access @@ -126,25 +122,23 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == "win32": + if sys.platform == 'win32': # Some env tweaks for platform compatibility. - environ["USER"] = environ["USERNAME"] - - config_file = "cli/test-cli-config.jsonc" - globals_file = "globals/global_test_config.jsonc" - cli_args = ( - " ".join([f"--config-path {config_path}" for config_path in config_paths]) - + f" --config {config_file}" - + " --service services/remote/mock/mock_auth_service.jsonc" - + " --service services/remote/mock/mock_remote_exec_service.jsonc" - + f" --globals {globals_file}" - + " --experiment_id MockeryExperiment" - + " --no-teardown" - + " --random-init" - + " --random-seed 1234" - + " --trial-config-repeat-count 5" - + " --max_trials 200" - ) + environ['USER'] = environ['USERNAME'] + + config_file = 'cli/test-cli-config.jsonc' + globals_file = 'globals/global_test_config.jsonc' + cli_args = ' '.join([f"--config-path {config_path}" for config_path in config_paths]) + \ + f' --config {config_file}' + \ + ' --service services/remote/mock/mock_auth_service.jsonc' + \ + ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ + f' --globals {globals_file}' + \ + ' --experiment_id MockeryExperiment' + \ + ' --no-teardown' + \ + ' --random-init' + \ + ' --random-seed 1234' + \ + ' --trial-config-repeat-count 5' + \ + ' --max_trials 200' launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -154,42 +148,35 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsRemoteExec) # Check that the --globals file is loaded and $var expansion is handled # using the value provided on the CLI. - assert launcher.global_config["experiment_id"] == "MockeryExperiment" - assert launcher.global_config["testVmName"] == "MockeryExperiment-vm" + assert launcher.global_config['experiment_id'] == 'MockeryExperiment' + assert launcher.global_config['testVmName'] == 'MockeryExperiment-vm' # Check that secondary expansion also works. - assert launcher.global_config["testVnetName"] == "MockeryExperiment-vm-vnet" + assert launcher.global_config['testVnetName'] == 'MockeryExperiment-vm-vnet' # Check that we can expand a $var in a config file that references an environment variable. 
- assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join( - os.getcwd(), "foo", abs_path=True - ) - assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ + == path_join(os.getcwd(), "foo", abs_path=True) + assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' assert not launcher.teardown config = launcher.config_loader.load_config(config_file, ConfigSchema.CLI) - assert launcher.config_loader.config_paths == [ - path_join(path, abs_path=True) for path in config_paths + config["config_path"] - ] + assert launcher.config_loader.config_paths == [path_join(path, abs_path=True) for path in config_paths + config['config_path']] # Check that the environment that got loaded looks to be of the right type. - env_config_file = config["environment"] + env_config_file = config['environment'] env_config = launcher.config_loader.load_config(env_config_file, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config["class"]) + assert check_class_name(launcher.environment, env_config['class']) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, MlosCoreOptimizer) - opt_config_file = config["optimizer"] + opt_config_file = config['optimizer'] opt_config = launcher.config_loader.load_config(opt_config_file, ConfigSchema.OPTIMIZER) globals_file_config = launcher.config_loader.load_config(globals_file, ConfigSchema.GLOBALS) # The actual global_config gets overwritten as a part of processing, so to test # this we read the original value out of the source files. - orig_max_iters = globals_file_config.get( - "max_suggestions", opt_config.get("config", {}).get("max_suggestions", 100) - ) - assert ( - launcher.optimizer.max_iterations - == orig_max_iters - == launcher.global_config["max_suggestions"] - ) + orig_max_iters = globals_file_config.get('max_suggestions', opt_config.get('config', {}).get('max_suggestions', 100)) + assert launcher.optimizer.max_iterations \ + == orig_max_iters \ + == launcher.global_config['max_suggestions'] # Check that the optimizer got initialized with random values instead of the defaults. # Note: the environment doesn't get updated until suggest() is called to @@ -206,12 +193,12 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access # Check that the value from the file is overridden by the CLI arg. - assert config["random_seed"] == 42 + assert config['random_seed'] == 42 # TODO: This isn't actually respected yet because the `--random-init` only # applies to a temporary Optimizer used to populate the initial values via # random sampling. # assert launcher.optimizer.seed == 1234 -if __name__ == "__main__": +if __name__ == '__main__': pytest.main([__file__, "-n1"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index 8fff9b5dd5..591501d275 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -31,21 +31,16 @@ def local_exec_service() -> LocalExecService: """ Test fixture for LocalExecService. 
""" - return LocalExecService( - parent=ConfigPersistenceService( - { - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - } - ) - ) + return LocalExecService(parent=ConfigPersistenceService({ + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + })) -def _launch_main_app( - root_path: str, local_exec_service: LocalExecService, cli_config: str, re_expected: List[str] -) -> None: +def _launch_main_app(root_path: str, local_exec_service: LocalExecService, + cli_config: str, re_expected: List[str]) -> None: """ Run mlos_bench command-line application with given config and check the results in the log. @@ -57,13 +52,10 @@ def _launch_main_app( # temp_dir = '/tmp' log_path = path_join(temp_dir, "mock-test.log") (return_code, _stdout, _stderr) = local_exec_service.local_exec( - [ - "./mlos_bench/mlos_bench/run.py" - + " --config_path ./mlos_bench/mlos_bench/tests/config/" - + f" {cli_config} --log_file '{log_path}'" - ], - cwd=root_path, - ) + ["./mlos_bench/mlos_bench/run.py" + + " --config_path ./mlos_bench/mlos_bench/tests/config/" + + f" {cli_config} --log_file '{log_path}'"], + cwd=root_path) assert return_code == 0 try: @@ -87,34 +79,33 @@ def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecServ and default tunable values and check the results in the log. """ _launch_main_app( - root_path, - local_exec_service, - " --config cli/mock-bench.jsonc" - + " --trial_config_repeat_count 5" - + " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, local_exec_service, + " --config cli/mock-bench.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. [ - f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", - ], + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", + ] ) def test_launch_main_app_bench_values( - root_path: str, local_exec_service: LocalExecService -) -> None: + root_path: str, local_exec_service: LocalExecService) -> None: """ Run mlos_bench command-line application with mock benchmark config and user-specified tunable values and check the results in the log. """ _launch_main_app( - root_path, - local_exec_service, - " --config cli/mock-bench.jsonc" - + " --tunable_values tunable-values/tunable-values-example.jsonc" - + " --trial_config_repeat_count 5" - + " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, local_exec_service, + " --config cli/mock-bench.jsonc" + + " --tunable_values tunable-values/tunable-values-example.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. [ - f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", - ], + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", + ] ) @@ -124,23 +115,23 @@ def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecServic and check the results in the log. """ _launch_main_app( - root_path, - local_exec_service, - "--config cli/mock-opt.jsonc" - + " --trial_config_repeat_count 3" - + " --max_suggestions 3" - + " --mock_env_seed 42", # Noisy Mock Environment. + root_path, local_exec_service, + "--config cli/mock-opt.jsonc" + + " --trial_config_repeat_count 3" + + " --max_suggestions 3" + + " --mock_env_seed 42", # Noisy Mock Environment. 
[ # Iteration 1: Expect first value to be the baseline - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " - + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", # Iteration 2: The result may not always be deterministic - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " - + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Iteration 3: non-deterministic (depends on the optimizer) - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " - + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Final result: baseline is the optimum for the mock environment - f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", - ], + f"^{_RE_DATE} run\\.py:\\d+ " + + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", + ] ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index 924224365c..59a0fac13b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -23,29 +23,29 @@ def mock_configs() -> List[dict]: """ return [ { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 50000, - "kernel_sched_latency_ns": 1000000, + 'vmSize': 'Standard_B4ms', + 'idle': 'halt', + 'kernel_sched_migration_cost_ns': 50000, + 'kernel_sched_latency_ns': 1000000, }, { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000, - "kernel_sched_latency_ns": 2000000, + 'vmSize': 'Standard_B4ms', + 'idle': 'halt', + 'kernel_sched_migration_cost_ns': 40000, + 'kernel_sched_latency_ns': 2000000, }, { - "vmSize": "Standard_B4ms", - "idle": "mwait", - "kernel_sched_migration_cost_ns": -1, # Special value - "kernel_sched_latency_ns": 3000000, + 'vmSize': 'Standard_B4ms', + 'idle': 'mwait', + 'kernel_sched_migration_cost_ns': -1, # Special value + 'kernel_sched_latency_ns': 3000000, }, { - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 200000, - "kernel_sched_latency_ns": 4000000, - }, + 'vmSize': 'Standard_B2s', + 'idle': 'mwait', + 'kernel_sched_migration_cost_ns': 200000, + 'kernel_sched_latency_ns': 4000000, + } ] @@ -61,7 +61,7 @@ def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: "optimization_targets": {"score": "min"}, "max_suggestions": 5, "start_with_defaults": False, - "seed": SEED, + "seed": SEED }, ) @@ -74,7 +74,11 @@ def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: return MockOptimizer( tunables=tunable_groups, service=None, - config={"optimization_targets": {"score": "min"}, "max_suggestions": 5, "seed": SEED}, + config={ + "optimization_targets": {"score": "min"}, + "max_suggestions": 5, + "seed": SEED + }, ) @@ -86,7 +90,11 @@ def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: return MockOptimizer( tunables=tunable_groups, service=None, - config={"optimization_targets": {"score": "max"}, "max_suggestions": 10, "seed": SEED}, + config={ + "optimization_targets": {"score": "max"}, + "max_suggestions": 10, + "seed": SEED + }, ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py 
b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index add2945d74..9e9ce25d6f 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -20,7 +20,6 @@ # pylint: disable=redefined-outer-name - @pytest.fixture def grid_search_tunables_config() -> dict: """ @@ -52,22 +51,14 @@ def grid_search_tunables_config() -> dict: @pytest.fixture -def grid_search_tunables_grid( - grid_search_tunables: TunableGroups, -) -> List[Dict[str, TunableValue]]: +def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]: """ Test fixture for grid from tunable groups. Used to check that the grids are the same (ignoring order). """ - tunables_params_values = [ - tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None - ] - tunable_names = tuple( - tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None - ) - return list( - dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values) - ) + tunables_params_values = [tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None] + tunable_names = tuple(tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None) + return list(dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values)) @pytest.fixture @@ -79,9 +70,8 @@ def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups: @pytest.fixture -def grid_search_opt( - grid_search_tunables: TunableGroups, grid_search_tunables_grid: List[Dict[str, TunableValue]] -) -> GridSearchOptimizer: +def grid_search_opt(grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer: """ Test fixture for grid search optimizer. """ @@ -89,20 +79,15 @@ def grid_search_opt( # Test the convergence logic by controlling the number of iterations to be not a # multiple of the number of elements in the grid. max_iterations = len(grid_search_tunables_grid) * 2 - 3 - return GridSearchOptimizer( - tunables=grid_search_tunables, - config={ - "max_suggestions": max_iterations, - "optimization_targets": {"score": "max", "other_score": "min"}, - }, - ) + return GridSearchOptimizer(tunables=grid_search_tunables, config={ + "max_suggestions": max_iterations, + "optimization_targets": {"score": "max", "other_score": "min"}, + }) -def test_grid_search_grid( - grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]], -) -> None: +def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: """ Make sure that grid search optimizer initializes and works correctly. """ @@ -129,11 +114,9 @@ def test_grid_search_grid( # assert grid_search_opt.pending_configs == grid_search_tunables_grid -def test_grid_search( - grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]], -) -> None: +def test_grid_search(grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: """ Make sure that grid search optimizer initializes and works correctly. 
""" @@ -160,9 +143,7 @@ def test_grid_search( grid_search_tunables_grid.remove(default_config) assert default_config not in grid_search_opt.pending_configs assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all( - config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid - ) + assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) # The next suggestion should be a different element in the grid search. suggestion = grid_search_opt.suggest() @@ -176,9 +157,7 @@ def test_grid_search( grid_search_tunables_grid.remove(suggestion.get_param_values()) assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all( - config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid - ) + assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) # We consider not_converged as either having reached "max_suggestions" or an empty grid? @@ -244,7 +223,7 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: assert best_suggestion_dict not in grid_search_opt.suggested_configs best_suggestion_score: Dict[str, TunableValue] = {} - for opt_target, opt_dir in grid_search_opt.targets.items(): + for (opt_target, opt_dir) in grid_search_opt.targets.items(): val = score[opt_target] assert isinstance(val, (int, float)) best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1 @@ -258,54 +237,36 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: # Check bulk register suggested = [grid_search_opt.suggest() for _ in range(suggest_count)] - assert all( - suggestion.get_param_values() not in grid_search_opt.pending_configs - for suggestion in suggested - ) - assert all( - suggestion.get_param_values() in grid_search_opt.suggested_configs - for suggestion in suggested - ) + assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) + assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested) # Those new suggestions also shouldn't be in the set of previously suggested configs. 
assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested) - grid_search_opt.bulk_register( - [suggestion.get_param_values() for suggestion in suggested], - [score] * len(suggested), - [status] * len(suggested), - ) - - assert all( - suggestion.get_param_values() not in grid_search_opt.pending_configs - for suggestion in suggested - ) - assert all( - suggestion.get_param_values() not in grid_search_opt.suggested_configs - for suggestion in suggested - ) + grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested], + [score] * len(suggested), + [status] * len(suggested)) + + assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) + assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested) best_score, best_config = grid_search_opt.get_best_observation() assert best_score == best_suggestion_score assert best_config == best_suggestion -def test_grid_search_register( - grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups -) -> None: +def test_grid_search_register(grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups) -> None: """ Make sure that the `.register()` method adjusts the score signs correctly. """ assert grid_search_opt.register( - grid_search_tunables, - Status.SUCCEEDED, - { + grid_search_tunables, Status.SUCCEEDED, { "score": 1.0, "other_score": 2.0, - }, - ) == { - "score": -1.0, # max - "other_score": 2.0, # min + }) == { + "score": -1.0, # max + "other_score": 2.0, # min } assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == { diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py index 3a0ef7db2e..6549a8795c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py @@ -34,8 +34,7 @@ def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: "optimizer_type": "SMAC", "seed": SEED, # "start_with_defaults": False, - }, - ) + }) @pytest.fixture @@ -62,6 +61,6 @@ def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list assert best_score["score"] == pytest.approx(66.66, 0.01) -if __name__ == "__main__": +if __name__ == '__main__': # For attaching debugger debugging: pytest.main(["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__]) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py index c824d9774f..7ebba0e664 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py @@ -24,9 +24,9 @@ def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer: An instance of a mlos_core optimizer (FLAML-based). 
""" test_opt_config = { - "optimizer_type": "FLAML", - "max_suggestions": 10, - "seed": SEED, + 'optimizer_type': 'FLAML', + 'max_suggestions': 10, + 'seed': SEED, } return MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -39,44 +39,44 @@ def test_df(mlos_core_optimizer: MlosCoreOptimizer, mock_configs: List[dict]) -> assert isinstance(df_config, pandas.DataFrame) assert df_config.shape == (4, 6) assert set(df_config.columns) == { - "kernel_sched_latency_ns", - "kernel_sched_migration_cost_ns", - "kernel_sched_migration_cost_ns!type", - "kernel_sched_migration_cost_ns!special", - "idle", - "vmSize", + 'kernel_sched_latency_ns', + 'kernel_sched_migration_cost_ns', + 'kernel_sched_migration_cost_ns!type', + 'kernel_sched_migration_cost_ns!special', + 'idle', + 'vmSize', } - assert df_config.to_dict(orient="records") == [ + assert df_config.to_dict(orient='records') == [ { - "idle": "halt", - "kernel_sched_latency_ns": 1000000, - "kernel_sched_migration_cost_ns": 50000, - "kernel_sched_migration_cost_ns!special": None, - "kernel_sched_migration_cost_ns!type": "range", - "vmSize": "Standard_B4ms", + 'idle': 'halt', + 'kernel_sched_latency_ns': 1000000, + 'kernel_sched_migration_cost_ns': 50000, + 'kernel_sched_migration_cost_ns!special': None, + 'kernel_sched_migration_cost_ns!type': 'range', + 'vmSize': 'Standard_B4ms', }, { - "idle": "halt", - "kernel_sched_latency_ns": 2000000, - "kernel_sched_migration_cost_ns": 40000, - "kernel_sched_migration_cost_ns!special": None, - "kernel_sched_migration_cost_ns!type": "range", - "vmSize": "Standard_B4ms", + 'idle': 'halt', + 'kernel_sched_latency_ns': 2000000, + 'kernel_sched_migration_cost_ns': 40000, + 'kernel_sched_migration_cost_ns!special': None, + 'kernel_sched_migration_cost_ns!type': 'range', + 'vmSize': 'Standard_B4ms', }, { - "idle": "mwait", - "kernel_sched_latency_ns": 3000000, - "kernel_sched_migration_cost_ns": None, # The value is special! - "kernel_sched_migration_cost_ns!special": -1, - "kernel_sched_migration_cost_ns!type": "special", - "vmSize": "Standard_B4ms", + 'idle': 'mwait', + 'kernel_sched_latency_ns': 3000000, + 'kernel_sched_migration_cost_ns': None, # The value is special! + 'kernel_sched_migration_cost_ns!special': -1, + 'kernel_sched_migration_cost_ns!type': 'special', + 'vmSize': 'Standard_B4ms', }, { - "idle": "mwait", - "kernel_sched_latency_ns": 4000000, - "kernel_sched_migration_cost_ns": 200000, - "kernel_sched_migration_cost_ns!special": None, - "kernel_sched_migration_cost_ns!type": "range", - "vmSize": "Standard_B2s", + 'idle': 'mwait', + 'kernel_sched_latency_ns': 4000000, + 'kernel_sched_migration_cost_ns': 200000, + 'kernel_sched_migration_cost_ns!special': None, + 'kernel_sched_migration_cost_ns!type': 'range', + 'vmSize': 'Standard_B2s', }, ] diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index 9d696e01fa..fc62b4ff1b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -17,8 +17,8 @@ from mlos_bench.util import path_join from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer -_OUTPUT_DIR_PATH_BASE = r"c:/temp" if sys.platform == "win32" else "/tmp/" -_OUTPUT_DIR = "_test_output_dir" # Will be deleted after the test. +_OUTPUT_DIR_PATH_BASE = r'c:/temp' if sys.platform == 'win32' else '/tmp/' +_OUTPUT_DIR = '_test_output_dir' # Will be deleted after the test. 
def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) -> None: @@ -26,10 +26,10 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) Test invalid max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "max_trials": 10, - "max_suggestions": 11, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'max_trials': 10, + 'max_suggestions': 11, + 'seed': SEED, } with pytest.raises(AssertionError): opt = MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -41,14 +41,14 @@ def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> No Test max_trials initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "max_suggestions": 123, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'max_suggestions': 123, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) - assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config["max_suggestions"] + assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config['max_suggestions'] def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGroups) -> None: @@ -57,9 +57,9 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr """ output_dir = path_join(_OUTPUT_DIR_PATH_BASE, _OUTPUT_DIR) test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": output_dir, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': output_dir, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) @@ -67,8 +67,7 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr assert isinstance(opt._opt, SmacOptimizer) # Final portions of the path are generated by SMAC when run_name is not specified. assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - str(test_opt_config["output_directory"]) - ) + str(test_opt_config['output_directory'])) shutil.rmtree(output_dir) @@ -77,67 +76,56 @@ def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGr Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": _OUTPUT_DIR, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': _OUTPUT_DIR, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config["output_directory"])) - ) + path_join(os.getcwd(), str(test_opt_config['output_directory']))) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_run_name( - tunable_groups: TunableGroups, -) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_groups: TunableGroups) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. 
""" test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": _OUTPUT_DIR, - "run_name": "test_run", - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': _OUTPUT_DIR, + 'run_name': 'test_run', + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join( - os.getcwd(), str(test_opt_config["output_directory"]), str(test_opt_config["run_name"]) - ) - ) + path_join(os.getcwd(), str(test_opt_config['output_directory']), str(test_opt_config['run_name']))) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_experiment_id( - tunable_groups: TunableGroups, -) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunable_groups: TunableGroups) -> None: """ Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": _OUTPUT_DIR, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': _OUTPUT_DIR, + 'seed': SEED, } global_config = { - "experiment_id": "experiment_id", + 'experiment_id': 'experiment_id', } opt = MlosCoreOptimizer(tunable_groups, test_opt_config, global_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join( - os.getcwd(), str(test_opt_config["output_directory"]), global_config["experiment_id"] - ) - ) + path_join(os.getcwd(), str(test_opt_config['output_directory']), global_config['experiment_id'])) shutil.rmtree(_OUTPUT_DIR) @@ -146,9 +134,9 @@ def test_init_mlos_core_smac_temp_output_directory(tunable_groups: TunableGroups Test random output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - "optimizer_type": "SMAC", - "output_directory": None, - "seed": SEED, + 'optimizer_type': 'SMAC', + 'output_directory': None, + 'seed': SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py index b95d943272..a94a315939 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py @@ -20,33 +20,24 @@ def mock_configurations_no_defaults() -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. 
""" return [ - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 13112, - "kernel_sched_latency_ns": 796233790, - }, - 88.88, - ), - ( - { - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 117026, - "kernel_sched_latency_ns": 149827706, - }, - 66.66, - ), - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 354785, - "kernel_sched_latency_ns": 795285932, - }, - 99.99, - ), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 13112, + "kernel_sched_latency_ns": 796233790, + }, 88.88), + ({ + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 117026, + "kernel_sched_latency_ns": 149827706, + }, 66.66), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 354785, + "kernel_sched_latency_ns": 795285932, + }, 99.99), ] @@ -56,15 +47,12 @@ def mock_configurations(mock_configurations_no_defaults: list) -> list: A list of 2-tuples of (tunable_values, score) to test the optimizers. """ return [ - ( - { - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": -1, - "kernel_sched_latency_ns": 2000000, - }, - 88.88, - ), + ({ + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": -1, + "kernel_sched_latency_ns": 2000000, + }, 88.88), ] + mock_configurations_no_defaults @@ -72,7 +60,7 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: """ Run several iterations of the optimizer and return the best score. """ - for tunable_values, score in mock_configurations: + for (tunable_values, score) in mock_configurations: assert mock_opt.not_converged() tunables = mock_opt.suggest() assert tunables.get_param_values() == tunable_values @@ -92,9 +80,8 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> N assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_no_defaults( - mock_opt_no_defaults: MockOptimizer, mock_configurations_no_defaults: list -) -> None: +def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer, + mock_configurations_no_defaults: list) -> None: """ Make sure that mock optimizer produces consistent suggestions. """ diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py index ccc0ba8137..bf37040f13 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py @@ -25,7 +25,10 @@ def mock_configs_str(mock_configs: List[dict]) -> List[dict]: Same as `mock_config` above, but with all values converted to strings. (This can happen when we retrieve the data from storage). 
""" - return [{key: str(val) for (key, val) in config.items()} for config in mock_configs] + return [ + {key: str(val) for (key, val) in config.items()} + for config in mock_configs + ] @pytest.fixture @@ -49,12 +52,10 @@ def mock_status() -> List[Status]: return [Status.FAILED, Status.SUCCEEDED, Status.SUCCEEDED, Status.SUCCEEDED] -def _test_opt_update_min( - opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None, -) -> None: +def _test_opt_update_min(opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None) -> None: """ Test the bulk update of the optimizer on the minimization problem. """ @@ -67,16 +68,14 @@ def _test_opt_update_min( "vmSize": "Standard_B4ms", "idle": "mwait", "kernel_sched_migration_cost_ns": -1, - "kernel_sched_latency_ns": 3000000, + 'kernel_sched_latency_ns': 3000000, } -def _test_opt_update_max( - opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None, -) -> None: +def _test_opt_update_max(opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None) -> None: """ Test the bulk update of the optimizer on the maximization problem. """ @@ -89,16 +88,14 @@ def _test_opt_update_max( "vmSize": "Standard_B2s", "idle": "mwait", "kernel_sched_migration_cost_ns": 200000, - "kernel_sched_latency_ns": 4000000, + 'kernel_sched_latency_ns': 4000000, } -def test_update_mock_min( - mock_opt: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_mock_min(mock_opt: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer on the minimization problem. """ @@ -108,76 +105,64 @@ def test_update_mock_min( "vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 13112, - "kernel_sched_latency_ns": 796233790, + 'kernel_sched_latency_ns': 796233790, } -def test_update_mock_min_str( - mock_opt: MockOptimizer, - mock_configs_str: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_mock_min_str(mock_opt: MockOptimizer, + mock_configs_str: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer with all-strings data. """ _test_opt_update_min(mock_opt, mock_configs_str, mock_scores, mock_status) -def test_update_mock_max( - mock_opt_max: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_mock_max(mock_opt_max: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the mock optimizer on the maximization problem. 
""" _test_opt_update_max(mock_opt_max, mock_configs, mock_scores, mock_status) -def test_update_flaml( - flaml_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_flaml(flaml_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_min(flaml_opt, mock_configs, mock_scores, mock_status) -def test_update_flaml_max( - flaml_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_flaml_max(flaml_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the FLAML optimizer. """ _test_opt_update_max(flaml_opt_max, mock_configs, mock_scores, mock_status) -def test_update_smac( - smac_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_smac(smac_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the SMAC optimizer. """ _test_opt_update_min(smac_opt, mock_configs, mock_scores, mock_status) -def test_update_smac_max( - smac_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status], -) -> None: +def test_update_smac_max(smac_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status]) -> None: """ Test the bulk update of the SMAC optimizer. """ diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index c30d1c32d2..2a50f95e8c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -56,7 +56,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: (status, _ts, output) = env_context.run() assert status.is_succeeded() assert output is not None - score = output["score"] + score = output['score'] assert isinstance(score, float) assert 60 <= score <= 120 logger("score: %s", str(score)) @@ -69,7 +69,8 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: return (best_score["score"], best_tunables) -def test_mock_optimization_loop(mock_env_no_noise: MockEnv, mock_opt: MockOptimizer) -> None: +def test_mock_optimization_loop(mock_env_no_noise: MockEnv, + mock_opt: MockOptimizer) -> None: """ Toy optimization loop with mock environment and optimizer. """ @@ -83,9 +84,8 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv, mock_opt: MockOptimi } -def test_mock_optimization_loop_no_defaults( - mock_env_no_noise: MockEnv, mock_opt_no_defaults: MockOptimizer -) -> None: +def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, + mock_opt_no_defaults: MockOptimizer) -> None: """ Toy optimization loop with mock environment and optimizer. 
""" @@ -99,7 +99,8 @@ def test_mock_optimization_loop_no_defaults( } -def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, flaml_opt: MlosCoreOptimizer) -> None: +def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, + flaml_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and FLAML optimizer. """ @@ -114,7 +115,8 @@ def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, flaml_opt: MlosCore # @pytest.mark.skip(reason="SMAC is not deterministic") -def test_smac_optimization_loop(mock_env_no_noise: MockEnv, smac_opt: MlosCoreOptimizer) -> None: +def test_smac_optimization_loop(mock_env_no_noise: MockEnv, + smac_opt: MlosCoreOptimizer) -> None: """ Toy optimization loop with mock environment and SMAC optimizer. """ diff --git a/mlos_bench/mlos_bench/tests/services/__init__.py b/mlos_bench/mlos_bench/tests/services/__init__.py index bf4df0e6c2..1971c01799 100644 --- a/mlos_bench/mlos_bench/tests/services/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/__init__.py @@ -11,8 +11,8 @@ from .remote import MockFileShareService, MockRemoteExecService, MockVMService __all__ = [ - "MockLocalExecService", - "MockFileShareService", - "MockRemoteExecService", - "MockVMService", + 'MockLocalExecService', + 'MockFileShareService', + 'MockRemoteExecService', + 'MockVMService', ] diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index 881b6b6cfa..d6cb869f09 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -29,19 +29,15 @@ def config_persistence_service() -> ConfigPersistenceService: """ Test fixture for ConfigPersistenceService. """ - return ConfigPersistenceService( - { - "config_path": [ - "./non-existent-dir/test/foo/bar", # Non-existent config path - ".", # cwd - str( - files("mlos_bench.tests.config").joinpath("") - ), # Test configs (relative to mlos_bench/tests) - # Shouldn't be necessary since we automatically add this. - # str(files("mlos_bench.config").joinpath("")), # Stock configs - ] - } - ) + return ConfigPersistenceService({ + "config_path": [ + "./non-existent-dir/test/foo/bar", # Non-existent config path + ".", # cwd + str(files("mlos_bench.tests.config").joinpath("")), # Test configs (relative to mlos_bench/tests) + # Shouldn't be necessary since we automatically add this. + # str(files("mlos_bench.config").joinpath("")), # Stock configs + ] + }) def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None: @@ -82,7 +78,7 @@ def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService assert os.path.exists(path) assert os.path.samefile( ConfigPersistenceService.BUILTIN_CONFIG_PATH, - os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]), + os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]) ) @@ -110,9 +106,8 @@ def test_load_config(config_persistence_service: ConfigPersistenceService) -> No """ Check if we can successfully load a config file located relative to `config_path`. 
""" - tunables_data = config_persistence_service.load_config( - "tunable-values/tunable-values-example.jsonc", ConfigSchema.TUNABLE_VALUES - ) + tunables_data = config_persistence_service.load_config("tunable-values/tunable-values-example.jsonc", + ConfigSchema.TUNABLE_VALUES) assert tunables_data is not None assert isinstance(tunables_data, dict) assert len(tunables_data) >= 1 diff --git a/mlos_bench/mlos_bench/tests/services/local/__init__.py b/mlos_bench/mlos_bench/tests/services/local/__init__.py index a09fd442fb..c6dbf7c021 100644 --- a/mlos_bench/mlos_bench/tests/services/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/__init__.py @@ -10,5 +10,5 @@ from .mock import MockLocalExecService __all__ = [ - "MockLocalExecService", + 'MockLocalExecService', ] diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index 78cebdf517..572195dcc5 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -56,12 +56,11 @@ def test_run_python_script(local_exec_service: LocalExecService) -> None: json.dump(params_meta, fh_meta) script_path = local_exec_service.config_loader_service.resolve_path( - "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py" - ) + "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py") - (return_code, _stdout, stderr) = local_exec_service.local_exec( - [f"{script_path} {input_file} {meta_file} {output_file}"], cwd=temp_dir, env=params - ) + (return_code, _stdout, stderr) = local_exec_service.local_exec([ + f"{script_path} {input_file} {meta_file} {output_file}" + ], cwd=temp_dir, env=params) assert stderr.strip() == "" assert return_code == 0 diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index c9dbecd93c..bd5b3b7d7f 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -24,27 +24,25 @@ def test_split_cmdline() -> None: """ Test splitting a commandline into subcommands. """ - cmdline = ( - ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" - ) + cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" assert list(split_cmdline(cmdline)) == [ - [".", "env.sh"], - ["&&"], - ["("], - ["echo", "hello"], - ["&&"], - ["echo", "world"], - ["|"], - ["tee"], - [">"], - ["/tmp/test"], - ["||"], - ["echo", "foo"], - ["&&"], - ["echo", "$var"], - [";"], - ["true"], - [")"], + ['.', 'env.sh'], + ['&&'], + ['('], + ['echo', 'hello'], + ['&&'], + ['echo', 'world'], + ['|'], + ['tee'], + ['>'], + ['/tmp/test'], + ['||'], + ['echo', 'foo'], + ['&&'], + ['echo', '$var'], + [';'], + ['true'], + [')'], ] @@ -69,10 +67,7 @@ def test_resolve_script(local_exec_service: LocalExecService) -> None: expected_cmdline = f". 
env.sh && {script_abspath} --input foo" subcmds_tokens = split_cmdline(orig_cmdline) # pylint: disable=protected-access - subcmds_tokens = [ - local_exec_service._resolve_cmdline_script_path(subcmd_tokens) - for subcmd_tokens in subcmds_tokens - ] + subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens] cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens] expanded_cmdline = " ".join(cmdline_tokens) assert expanded_cmdline == expected_cmdline @@ -94,7 +89,10 @@ def test_run_script_multiline(local_exec_service: LocalExecService) -> None: Run a multiline script locally and check the results. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec(["echo hello", "echo world"]) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + "echo hello", + "echo world" + ]) assert return_code == 0 assert stdout.strip().split() == ["hello", "world"] assert stderr.strip() == "" @@ -105,12 +103,12 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: Run a multiline script locally and pass the environment variables to it. """ # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec( - [r"echo $var", r"echo %var%"], # Unix shell # Windows cmd - env={"var": "VALUE", "int_var": 10}, - ) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + r"echo $var", # Unix shell + r"echo %var%" # Windows cmd + ], env={"var": "VALUE", "int_var": 10}) assert return_code == 0 - if sys.platform == "win32": + if sys.platform == 'win32': assert stdout.strip().split() == ["$var", "VALUE"] else: assert stdout.strip().split() == ["VALUE", "%var%"] @@ -123,26 +121,23 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: """ with local_exec_service.temp_dir_context() as temp_dir: - (return_code, stdout, stderr) = local_exec_service.local_exec( - [ - "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows - "echo '111,222' >> output.csv", - "echo '333,444' >> output.csv", - ], - cwd=temp_dir, - ) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows + "echo '111,222' >> output.csv", + "echo '333,444' >> output.csv", + ], cwd=temp_dir) assert return_code == 0 assert stdout.strip() == "" assert stderr.strip() == "" data = pandas.read_csv(path_join(temp_dir, "output.csv")) - if sys.platform == "win32": + if sys.platform == 'win32': # Workaround for Python's subprocess module on Windows adding a # space inbetween the col1,col2 arg and the redirect symbol which # cmd poorly interprets as being part of the original string arg. # Without this, we get "col2 " as the second column name. 
- data.rename(str.rstrip, axis="columns", inplace=True) + data.rename(str.rstrip, axis='columns', inplace=True) assert all(data.col1 == [111, 333]) assert all(data.col2 == [222, 444]) @@ -157,13 +152,10 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None with open(path_join(temp_dir, input_file), "wt", encoding="utf-8") as fh_input: fh_input.write("hello\n") - (return_code, stdout, stderr) = local_exec_service.local_exec( - [ - f"echo 'world' >> {input_file}", - f"echo 'test' >> {input_file}", - ], - cwd=temp_dir, - ) + (return_code, stdout, stderr) = local_exec_service.local_exec([ + f"echo 'world' >> {input_file}", + f"echo 'test' >> {input_file}", + ], cwd=temp_dir) assert return_code == 0 assert stdout.strip() == "" @@ -186,13 +178,11 @@ def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> N """ Try to run a series of commands, one of which fails, and abort early. """ - (return_code, stdout, _stderr) = local_exec_service.local_exec( - [ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == "win32" else "false", - "echo world", - ] - ) + (return_code, stdout, _stderr) = local_exec_service.local_exec([ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", + "echo world", + ]) assert return_code != 0 assert stdout.strip() == "hello" @@ -202,13 +192,11 @@ def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> No Try to run a series of commands, one of which fails, but let it pass. """ local_exec_service.abort_on_error = False - (return_code, stdout, _stderr) = local_exec_service.local_exec( - [ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == "win32" else "false", - "echo world", - ] - ) + (return_code, stdout, _stderr) = local_exec_service.local_exec([ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", + "echo world", + ]) assert return_code == 0 assert stdout.splitlines() == [ "hello", @@ -226,17 +214,13 @@ def test_temp_dir_path_expansion() -> None: # the fact. with tempfile.TemporaryDirectory() as temp_dir: global_config = { - "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" + "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" } config = { # The temp_dir for the LocalExecService should get expanded via workdir global config. 
"temp_dir": "$workdir/temp", } - local_exec_service = LocalExecService( - config, global_config, parent=ConfigPersistenceService() - ) + local_exec_service = LocalExecService(config, global_config, parent=ConfigPersistenceService()) # pylint: disable=protected-access assert isinstance(local_exec_service._temp_dir, str) - assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join( - temp_dir, "temp", abs_path=True - ) + assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join(temp_dir, "temp", abs_path=True) diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py index 9164da60df..eede9383bc 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py @@ -9,5 +9,5 @@ from .mock_local_exec_service import MockLocalExecService __all__ = [ - "MockLocalExecService", + 'MockLocalExecService', ] diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py index 3df89aaed9..db8f0134c4 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py @@ -35,21 +35,16 @@ class MockLocalExecService(TempDirContextService, SupportsLocalExec): Mock methods for LocalExecService testing. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__( - config, global_config, parent, self.merge_methods(methods, [self.local_exec]) + config, global_config, parent, + self.merge_methods(methods, [self.local_exec]) ) - def local_exec( - self, - script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None, - ) -> Tuple[int, str, str]: + def local_exec(self, script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None) -> Tuple[int, str, str]: return (0, "", "") diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py index 4ef38ab440..835738015b 100644 --- a/mlos_bench/mlos_bench/tests/services/mock_service.py +++ b/mlos_bench/mlos_bench/tests/services/mock_service.py @@ -28,24 +28,19 @@ class MockServiceBase(Service, SupportsSomeMethod): """A base service class for testing.""" def __init__( - self, - config: Optional[dict] = None, - global_config: Optional[dict] = None, - parent: Optional[Service] = None, - methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None, - ) -> None: + self, + config: Optional[dict] = None, + global_config: Optional[dict] = None, + parent: Optional[Service] = None, + methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None) -> None: super().__init__( config, global_config, parent, - self.merge_methods( - methods, - [ - self.some_method, - self.some_other_method, - ], - ), - ) + self.merge_methods(methods, [ + self.some_method, + self.some_other_method, + ])) def some_method(self) -> str: """some_method""" diff --git 
a/mlos_bench/mlos_bench/tests/services/remote/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/__init__.py index df3fb69c53..e8a87ab684 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/__init__.py @@ -12,7 +12,7 @@ from .mock.mock_vm_service import MockVMService __all__ = [ - "MockFileShareService", - "MockRemoteExecService", - "MockVMService", + 'MockFileShareService', + 'MockRemoteExecService', + 'MockVMService', ] diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index d451370b63..c6475e6936 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -18,9 +18,7 @@ @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_file( - mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService -) -> None: +def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" @@ -28,9 +26,8 @@ def test_download_file( local_path = f"{local_folder}/{filename}" mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object( - mock_share_client, "get_directory_client" - ) as mock_get_directory_client: + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \ + patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client: mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False)) azure_fileshare.download(config, remote_path, local_path) @@ -50,41 +47,38 @@ def make_dir_client_returns(remote_folder: str) -> dict: return { remote_folder: Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock( - return_value=[ - {"name": "a_folder", "is_directory": True}, - {"name": "a_file_1.csv", "is_directory": False}, - ] - ), + list_directories_and_files=Mock(return_value=[ + {"name": "a_folder", "is_directory": True}, + {"name": "a_file_1.csv", "is_directory": False}, + ]) ), f"{remote_folder}/a_folder": Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock( - return_value=[ - {"name": "a_file_2.csv", "is_directory": False}, - ] - ), + list_directories_and_files=Mock(return_value=[ + {"name": "a_file_2.csv", "is_directory": False}, + ]) + ), + f"{remote_folder}/a_file_1.csv": Mock( + exists=Mock(return_value=False) + ), + f"{remote_folder}/a_folder/a_file_2.csv": Mock( + exists=Mock(return_value=False) ), - f"{remote_folder}/a_file_1.csv": Mock(exists=Mock(return_value=False)), - f"{remote_folder}/a_folder/a_file_2.csv": Mock(exists=Mock(return_value=False)), } @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_non_recursive( - mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService -) -> None: +def test_download_folder_non_recursive(mock_makedirs: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = 
"some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object( - mock_share_client, "get_directory_client" - ) as mock_get_directory_client, patch.object( - mock_share_client, "get_file_client" - ) as mock_get_file_client: + with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ + patch.object(mock_share_client, "get_file_client") as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] @@ -93,63 +87,47 @@ def test_download_folder_non_recursive( mock_get_file_client.assert_called_with( f"{remote_folder}/a_file_1.csv", ) - mock_get_directory_client.assert_has_calls( - [ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - ], - any_order=True, - ) + mock_get_directory_client.assert_has_calls([ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + ], any_order=True) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_recursive( - mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService -) -> None: +def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object( - mock_share_client, "get_directory_client" - ) as mock_get_directory_client, patch.object( - mock_share_client, "get_file_client" - ) as mock_get_file_client: + with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ + patch.object(mock_share_client, "get_file_client") as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] azure_fileshare.download(config, remote_folder, local_folder, recursive=True) - mock_get_file_client.assert_has_calls( - [ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], - any_order=True, - ) - mock_get_directory_client.assert_has_calls( - [ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], - any_order=True, - ) + mock_get_file_client.assert_has_calls([ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], any_order=True) + mock_get_directory_client.assert_has_calls([ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], any_order=True) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") -def test_upload_file( - mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService -) -> None: +def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = 
"some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access mock_isdir.return_value = False config: dict = {} @@ -165,7 +143,6 @@ def test_upload_file( class MyDirEntry: # pylint: disable=too-few-public-methods """Dummy class for os.DirEntry""" - def __init__(self, name: str, is_a_dir: bool): self.name = name self.is_a_dir = is_a_dir @@ -209,19 +186,17 @@ def process_paths(input_path: str) -> str: @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_non_recursive( - mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService, -) -> None: +def test_upload_directory_non_recursive(mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: @@ -233,28 +208,23 @@ def test_upload_directory_non_recursive( @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_recursive( - mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService, -) -> None: +def test_upload_directory_recursive(mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: azure_fileshare.upload(config, local_folder, remote_folder, recursive=True) - mock_get_file_client.assert_has_calls( - [ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], - any_order=True, - ) + mock_get_file_client.assert_has_calls([ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], any_order=True) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py index af239a158e..d6d55d3975 100644 --- 
a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py @@ -18,20 +18,16 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), - [ + ("total_retries", "operation_status"), [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ], -) + ]) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_network_deployment_retry( - mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_network_service: AzureNetworkService, -) -> None: +def test_wait_network_deployment_retry(mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_network_service: AzureNetworkService) -> None: """ Test retries of the network deployment operation. """ @@ -39,12 +35,8 @@ def test_wait_network_deployment_retry( # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] @@ -57,37 +49,30 @@ def test_wait_network_deployment_retry( "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True, - ) + is_setup=True) assert status == operation_status @pytest.mark.parametrize( - ("operation_name", "accepts_params"), - [ + ("operation_name", "accepts_params"), [ ("deprovision_network", True), - ], -) + ]) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), - [ + ("http_status_code", "operation_status"), [ (200, Status.SUCCEEDED), (202, Status.PENDING), # These should succeed since we set ignore_errors=True by default (401, Status.SUCCEEDED), (404, Status.SUCCEEDED), - ], -) + ]) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_network_operation_status( - mock_requests: MagicMock, - azure_network_service: AzureNetworkService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status, -) -> None: +def test_network_operation_status(mock_requests: MagicMock, + azure_network_service: AzureNetworkService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status) -> None: """ Test network operation status. """ @@ -104,30 +89,22 @@ def test_network_operation_status( @pytest.fixture -def test_azure_network_service_no_deployment_template( - azure_auth_service: AzureAuthService, -) -> None: +def test_azure_network_service_no_deployment_template(azure_auth_service: AzureAuthService) -> None: """ Tests creating a network services without a deployment template (should fail). 
""" with pytest.raises(ValueError): - _ = AzureNetworkService( - config={ - "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", - }, + _ = AzureNetworkService(config={ + "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", }, - parent=azure_auth_service, - ) + }, parent=azure_auth_service) with pytest.raises(ValueError): - _ = AzureNetworkService( - config={ - # "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", - }, + _ = AzureNetworkService(config={ + # "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", }, - parent=azure_auth_service, - ) + }, parent=azure_auth_service) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index fc72131c0c..1d84d73cab 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -19,20 +19,16 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), - [ + ("total_retries", "operation_status"), [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ], -) + ]) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_host_deployment_retry( - mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService, -) -> None: +def test_wait_host_deployment_retry(mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService) -> None: """ Test retries of the host deployment operation. """ @@ -40,12 +36,8 @@ def test_wait_host_deployment_retry( # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] @@ -58,8 +50,7 @@ def test_wait_host_deployment_retry( "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True, - ) + is_setup=True) assert status == operation_status @@ -84,14 +75,8 @@ def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAut } azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) assert azure_vm_service.deploy_params["location"] == global_config["location"] - assert ( - azure_vm_service.deploy_params["vmMeta"] - == f'{global_config["vmName"]}-{global_config["location"]}' - ) - assert ( - azure_vm_service.deploy_params["vmNsg"] - == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' - ) + assert azure_vm_service.deploy_params["vmMeta"] == f'{global_config["vmName"]}-{global_config["location"]}' + assert azure_vm_service.deploy_params["vmNsg"] == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' def 
test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None: @@ -113,15 +98,14 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N } with pytest.raises(ValueError): config_with_custom_data = deepcopy(config) - config_with_custom_data["deploymentTemplateParameters"]["customData"] = "DUMMY_CUSTOM_DATA" # type: ignore[index] + config_with_custom_data['deploymentTemplateParameters']['customData'] = "DUMMY_CUSTOM_DATA" # type: ignore[index] AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service) azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) - assert azure_vm_service.deploy_params["customData"] + assert azure_vm_service.deploy_params['customData'] @pytest.mark.parametrize( - ("operation_name", "accepts_params"), - [ + ("operation_name", "accepts_params"), [ ("start_host", True), ("stop_host", True), ("shutdown", True), @@ -129,27 +113,22 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N ("deallocate_host", True), ("restart_host", True), ("reboot", True), - ], -) + ]) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), - [ + ("http_status_code", "operation_status"), [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ], -) + ]) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_vm_operation_status( - mock_requests: MagicMock, - azure_vm_service: AzureVMService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status, -) -> None: +def test_vm_operation_status(mock_requests: MagicMock, + azure_vm_service: AzureVMService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status) -> None: """ Test VM operation status. """ @@ -166,14 +145,12 @@ def test_vm_operation_status( @pytest.mark.parametrize( - ("operation_name", "accepts_params"), - [ + ("operation_name", "accepts_params"), [ ("provision_host", True), - ], -) -def test_vm_operation_invalid( - azure_vm_service_remote_exec_only: AzureVMService, operation_name: str, accepts_params: bool -) -> None: + ]) +def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, + operation_name: str, + accepts_params: bool) -> None: """ Test VM operation status for an incomplete service config. """ @@ -184,9 +161,8 @@ def test_vm_operation_invalid( @patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep") @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_ready( - mock_session: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService -) -> None: +def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, + azure_vm_service: AzureVMService) -> None: """ Test waiting for the completion of the remote VM operation. 
""" @@ -207,20 +183,23 @@ def test_wait_vm_operation_ready( status, _ = azure_vm_service.wait_host_operation(params) - assert (async_url,) == mock_session.return_value.get.call_args[0] - assert (retry_after,) == mock_sleep.call_args[0] + assert (async_url, ) == mock_session.return_value.get.call_args[0] + assert (retry_after, ) == mock_sleep.call_args[0] assert status.is_succeeded() @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_timeout( - mock_session: MagicMock, azure_vm_service: AzureVMService -) -> None: +def test_wait_vm_operation_timeout(mock_session: MagicMock, + azure_vm_service: AzureVMService) -> None: """ Test the time out of the remote VM operation. """ # Mock response header - params = {"asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", "pollInterval": 1} + params = { + "asyncResultsUrl": "DUMMY_ASYNC_URL", + "vmName": "test-vm", + "pollInterval": 1 + } mock_status_response = MagicMock(status_code=200) mock_status_response.json.return_value = { @@ -233,20 +212,16 @@ def test_wait_vm_operation_timeout( @pytest.mark.parametrize( - ("total_retries", "operation_status"), - [ + ("total_retries", "operation_status"), [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ], -) + ]) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_vm_operation_retry( - mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService, -) -> None: +def test_wait_vm_operation_retry(mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService) -> None: """ Test the retries of the remote VM operation. """ @@ -254,12 +229,8 @@ def test_wait_vm_operation_retry( # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"status": "InProgress"}), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), - requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") - ), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), make_httplib_json_response(200, {"status": "InProgress"}), make_httplib_json_response(200, {"status": "Succeeded"}), ] @@ -270,27 +241,20 @@ def test_wait_vm_operation_retry( "requestTotalRetries": total_retries, "asyncResultsUrl": "https://DUMMY_ASYNC_URL", "vmName": "test-vm", - } - ) + }) assert status == operation_status @pytest.mark.parametrize( - ("http_status_code", "operation_status"), - [ + ("http_status_code", "operation_status"), [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ], -) + ]) @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_status( - mock_requests: MagicMock, - azure_vm_service_remote_exec_only: AzureVMService, - http_status_code: int, - operation_status: Status, -) -> None: +def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService, + http_status_code: int, operation_status: Status) -> None: """ Test waiting for completion of the remote execution on Azure. 
""" @@ -298,24 +262,19 @@ def test_remote_exec_status( mock_response = MagicMock() mock_response.status_code = http_status_code - mock_response.json = MagicMock( - return_value={ - "fake response": "body as json to dict", - } - ) + mock_response.json = MagicMock(return_value={ + "fake response": "body as json to dict", + }) mock_requests.post.return_value = mock_response - status, _ = azure_vm_service_remote_exec_only.remote_exec( - script, config={"vmName": "test-vm"}, env_params={} - ) + status, _ = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={}) assert status == operation_status @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_headers_output( - mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService -) -> None: +def test_remote_exec_headers_output(mock_requests: MagicMock, + azure_vm_service_remote_exec_only: AzureVMService) -> None: """ Check if HTTP headers from the remote execution on Azure are correct. """ @@ -325,22 +284,18 @@ def test_remote_exec_headers_output( mock_response = MagicMock() mock_response.status_code = 202 - mock_response.headers = {"Azure-AsyncOperation": async_url_value} - mock_response.json = MagicMock( - return_value={ - "fake response": "body as json to dict", - } - ) + mock_response.headers = { + "Azure-AsyncOperation": async_url_value + } + mock_response.json = MagicMock(return_value={ + "fake response": "body as json to dict", + }) mock_requests.post.return_value = mock_response - _, cmd_output = azure_vm_service_remote_exec_only.remote_exec( - script, - config={"vmName": "test-vm"}, - env_params={ - "param_1": 123, - "param_2": "abc", - }, - ) + _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={ + "param_1": 123, + "param_2": "abc", + }) assert async_url_key in cmd_output assert cmd_output[async_url_key] == async_url_value @@ -348,13 +303,15 @@ def test_remote_exec_headers_output( assert mock_requests.post.call_args[1]["json"] == { "commandId": "RunShellScript", "script": script, - "parameters": [{"name": "param_1", "value": 123}, {"name": "param_2", "value": "abc"}], + "parameters": [ + {"name": "param_1", "value": 123}, + {"name": "param_2", "value": "abc"} + ] } @pytest.mark.parametrize( - ("operation_status", "wait_output", "results_output"), - [ + ("operation_status", "wait_output", "results_output"), [ ( Status.SUCCEEDED, { @@ -366,18 +323,13 @@ def test_remote_exec_headers_output( } } }, - {"stdout": "DUMMY_STDOUT_STDERR"}, + {"stdout": "DUMMY_STDOUT_STDERR"} ), (Status.PENDING, {}, {}), (Status.FAILED, {}, {}), - ], -) -def test_get_remote_exec_results( - azure_vm_service_remote_exec_only: AzureVMService, - operation_status: Status, - wait_output: dict, - results_output: dict, -) -> None: + ]) +def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status, + wait_output: dict, results_output: dict) -> None: """ Test getting the results of the remote execution on Azure. 
""" diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 6a2d62267b..2794bb01cf 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -30,9 +30,8 @@ def config_persistence_service() -> ConfigPersistenceService: @pytest.fixture -def azure_auth_service( - config_persistence_service: ConfigPersistenceService, monkeypatch: pytest.MonkeyPatch -) -> AzureAuthService: +def azure_auth_service(config_persistence_service: ConfigPersistenceService, + monkeypatch: pytest.MonkeyPatch) -> AzureAuthService: """ Creates a dummy AzureAuthService for tests that require it. """ @@ -46,23 +45,19 @@ def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkS """ Creates a dummy Azure VM service for tests that require it. """ - return AzureNetworkService( - config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", - }, - "pollInterval": 1, - "pollTimeout": 2, + return AzureNetworkService(config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", }, - global_config={ - "deploymentName": "TEST_DEPLOYMENT-VNET", - "vnetName": "test-vnet", # Should come from the upper-level config - }, - parent=azure_auth_service, - ) + "pollInterval": 1, + "pollTimeout": 2 + }, global_config={ + "deploymentName": "TEST_DEPLOYMENT-VNET", + "vnetName": "test-vnet", # Should come from the upper-level config + }, parent=azure_auth_service) @pytest.fixture @@ -70,23 +65,19 @@ def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService: """ Creates a dummy Azure VM service for tests that require it. """ - return AzureVMService( - config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", - }, - "pollInterval": 1, - "pollTimeout": 2, - }, - global_config={ - "deploymentName": "TEST_DEPLOYMENT-VM", - "vmName": "test-vm", # Should come from the upper-level config + return AzureVMService(config={ + "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", }, - parent=azure_auth_service, - ) + "pollInterval": 1, + "pollTimeout": 2 + }, global_config={ + "deploymentName": "TEST_DEPLOYMENT-VM", + "vmName": "test-vm", # Should come from the upper-level config + }, parent=azure_auth_service) @pytest.fixture @@ -94,18 +85,14 @@ def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> A """ Creates a dummy Azure VM service with no deployment template. 
""" - return AzureVMService( - config={ - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "pollInterval": 1, - "pollTimeout": 2, - }, - global_config={ - "vmName": "test-vm", # Should come from the upper-level config - }, - parent=azure_auth_service, - ) + return AzureVMService(config={ + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "pollInterval": 1, + "pollTimeout": 2, + }, global_config={ + "vmName": "test-vm", # Should come from the upper-level config + }, parent=azure_auth_service) @pytest.fixture @@ -114,12 +101,8 @@ def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> Azu Creates a dummy AzureFileShareService for tests that require it. """ with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"): - return AzureFileShareService( - config={ - "storageAccountName": "TEST_ACCOUNT_NAME", - "storageFileShareName": "TEST_FS_NAME", - "storageAccountKey": "TEST_ACCOUNT_KEY", - }, - global_config={}, - parent=config_persistence_service, - ) + return AzureFileShareService(config={ + "storageAccountName": "TEST_ACCOUNT_NAME", + "storageFileShareName": "TEST_FS_NAME", + "storageAccountKey": "TEST_ACCOUNT_KEY" + }, global_config={}, parent=config_persistence_service) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index fb1c4ee39b..b9474f0709 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -20,24 +20,16 @@ class MockAuthService(Service, SupportsAuth): A collection Service functions for mocking authentication ops. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - [ - self.get_access_token, - self.get_auth_headers, - ], - ), + config, global_config, parent, + self.merge_methods(methods, [ + self.get_access_token, + self.get_auth_headers, + ]) ) def get_access_token(self) -> str: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index 79f8c608c2..1a026966a8 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -21,30 +21,21 @@ class MockFileShareService(FileShareService, SupportsFileShareOps): A collection Service functions for mocking file share ops. 
""" - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): super().__init__( - config, - global_config, - parent, - self.merge_methods(methods, [self.upload, self.download]), + config, global_config, parent, + self.merge_methods(methods, [self.upload, self.download]) ) self._upload: List[Tuple[str, str]] = [] self._download: List[Tuple[str, str]] = [] - def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True - ) -> None: + def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: self._upload.append((local_path, remote_path)) - def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True - ) -> None: + def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: self._download.append((remote_path, local_path)) def get_upload(self) -> List[Tuple[str, str]]: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py index 6bf9fc8d05..e6169d9f93 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py @@ -20,13 +20,10 @@ class MockNetworkService(Service, SupportsNetworkProvisioning): Mock Network service for testing. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of mock network services proxy. @@ -41,19 +38,13 @@ def __init__( Parent service that can provide mixin functions. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - { - name: mock_operation - for name in ( - # SupportsNetworkProvisioning: - "provision_network", - "deprovision_network", - "wait_network_deployment", - ) - }, - ), + config, global_config, parent, + self.merge_methods(methods, { + name: mock_operation for name in ( + # SupportsNetworkProvisioning: + "provision_network", + "deprovision_network", + "wait_network_deployment", + ) + }) ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py index 38d759f53c..ee99251c64 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py @@ -18,13 +18,10 @@ class MockRemoteExecService(Service, SupportsRemoteExec): Mock remote script execution service. 
""" - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of mock remote exec service. @@ -39,14 +36,9 @@ def __init__( Parent service that can provide mixin functions. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - { - "remote_exec": mock_operation, - "get_remote_exec_results": mock_operation, - }, - ), + config, global_config, parent, + self.merge_methods(methods, { + "remote_exec": mock_operation, + "get_remote_exec_results": mock_operation, + }) ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index 3ae13cf6a6..a44edaf080 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -20,13 +20,10 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, Supports Mock VM service for testing. """ - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None, - ): + def __init__(self, config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None): """ Create a new instance of mock VM services proxy. @@ -41,29 +38,23 @@ def __init__( Parent service that can provide mixin functions. """ super().__init__( - config, - global_config, - parent, - self.merge_methods( - methods, - { - name: mock_operation - for name in ( - # SupportsHostProvisioning: - "wait_host_deployment", - "provision_host", - "deprovision_host", - "deallocate_host", - # SupportsHostOps: - "start_host", - "stop_host", - "restart_host", - "wait_host_operation", - # SupportsOsOps: - "shutdown", - "reboot", - "wait_os_operation", - ) - }, - ), + config, global_config, parent, + self.merge_methods(methods, { + name: mock_operation for name in ( + # SupportsHostProvisioning: + "wait_host_deployment", + "provision_host", + "deprovision_host", + "deallocate_host", + # SupportsHostOps: + "start_host", + "stop_host", + "restart_host", + "wait_host_operation", + # SupportsOsOps: + "shutdown", + "reboot", + "wait_os_operation", + ) + }) ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index 16c88dc791..e0060d8047 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -17,9 +17,9 @@ # The SSH test server port and name. 
# See Also: docker-compose.yml SSH_TEST_SERVER_PORT = 2254 -SSH_TEST_SERVER_NAME = "ssh-server" -ALT_TEST_SERVER_NAME = "alt-server" -REBOOT_TEST_SERVER_NAME = "reboot-server" +SSH_TEST_SERVER_NAME = 'ssh-server' +ALT_TEST_SERVER_NAME = 'alt-server' +REBOOT_TEST_SERVER_NAME = 'reboot-server' @dataclass @@ -42,12 +42,8 @@ def get_port(self, uncached: bool = False) -> int: Note: this value can change when the service restarts so we can't rely on the DockerServices. """ if self._port is None or uncached: - port_cmd = run( - f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", - shell=True, - check=True, - capture_output=True, - ) + port_cmd = run(f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", + shell=True, check=True, capture_output=True) self._port = int(port_cmd.stdout.decode().strip().split(":")[1]) return self._port diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 1d9f570fdf..6f05fe953b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -30,28 +30,26 @@ # pylint: disable=redefined-outer-name -HOST_DOCKER_NAME = "host.docker.internal" +HOST_DOCKER_NAME = 'host.docker.internal' @pytest.fixture(scope="session") def ssh_test_server_hostname() -> str: """Returns the local hostname to use to connect to the test ssh server.""" - if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): + if sys.platform != 'win32' and resolve_host_name(HOST_DOCKER_NAME): # On Linux, if we're running in a docker container, we can use the # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. return HOST_DOCKER_NAME # Docker (Desktop) for Windows (WSL2) uses a special networking magic # to refer to the host machine as `localhost` when exposing ports. # In all other cases, assume we're executing directly inside conda on the host. - return "localhost" + return 'localhost' @pytest.fixture(scope="session") -def ssh_test_server( - ssh_test_server_hostname: str, - docker_compose_project_name: str, - locked_docker_services: DockerServices, -) -> Generator[SshTestServerInfo, None, None]: +def ssh_test_server(ssh_test_server_hostname: str, + docker_compose_project_name: str, + locked_docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]: """ Fixture for getting the ssh test server services setup via docker-compose using pytest-docker. 
@@ -68,35 +66,23 @@ def ssh_test_server( compose_project_name=docker_compose_project_name, service_name=SSH_TEST_SERVER_NAME, hostname=ssh_test_server_hostname, - username="root", - id_rsa_path=id_rsa_file.name, - ) - wait_docker_service_socket( - locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port() - ) + username='root', + id_rsa_path=id_rsa_file.name) + wait_docker_service_socket(locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port()) id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa" docker_cp_cmd = f"docker compose -p {docker_compose_project_name} cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}" - cmd = run( - docker_cp_cmd.split(), - check=True, - cwd=os.path.dirname(__file__), - capture_output=True, - text=True, - ) + cmd = run(docker_cp_cmd.split(), check=True, cwd=os.path.dirname(__file__), capture_output=True, text=True) if cmd.returncode != 0: - raise RuntimeError( - f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " - + f"[return={cmd.returncode}]: {str(cmd.stderr)}" - ) + raise RuntimeError(f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " + + f"[return={cmd.returncode}]: {str(cmd.stderr)}") os.chmod(id_rsa_file.name, 0o600) yield ssh_test_server_info # NamedTempFile deleted on context exit @pytest.fixture(scope="session") -def alt_test_server( - ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices -) -> SshTestServerInfo: +def alt_test_server(ssh_test_server: SshTestServerInfo, + locked_docker_services: DockerServices) -> SshTestServerInfo: """ Fixture for getting the second ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. @@ -109,18 +95,14 @@ def alt_test_server( service_name=ALT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path, - ) - wait_docker_service_socket( - locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port() - ) + id_rsa_path=ssh_test_server.id_rsa_path) + wait_docker_service_socket(locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port()) return alt_test_server_info @pytest.fixture(scope="session") -def reboot_test_server( - ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices -) -> SshTestServerInfo: +def reboot_test_server(ssh_test_server: SshTestServerInfo, + locked_docker_services: DockerServices) -> SshTestServerInfo: """ Fixture for getting the third ssh test server info from the docker-compose.yml. See additional notes in the ssh_test_server fixture above. 
@@ -133,13 +115,8 @@ def reboot_test_server( service_name=REBOOT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path, - ) - wait_docker_service_socket( - locked_docker_services, - reboot_test_server_info.hostname, - reboot_test_server_info.get_port(), - ) + id_rsa_path=ssh_test_server.id_rsa_path) + wait_docker_service_socket(locked_docker_services, reboot_test_server_info.hostname, reboot_test_server_info.get_port()) return reboot_test_server_info diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index c77c57def8..f2bbbe4b8a 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -52,9 +52,8 @@ def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, @requires_docker -def test_ssh_fileshare_single_file( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService -) -> None: +def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService single file download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -67,7 +66,7 @@ def test_ssh_fileshare_single_file( lines = [line + "\n" for line in lines] # 1. Write a local file and upload it. - with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: temp_file.writelines(lines) temp_file.flush() temp_file.close() @@ -79,7 +78,7 @@ def test_ssh_fileshare_single_file( ) # 2. Download the remote file and compare the contents. - with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: temp_file.close() ssh_fileshare_service.download( params=config, @@ -87,15 +86,14 @@ def test_ssh_fileshare_single_file( local_path=temp_file.name, ) # Download will replace the inode at that name, so we need to reopen the file. - with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: + with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == lines @requires_docker -def test_ssh_fileshare_recursive( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService -) -> None: +def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService recursive download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -115,16 +113,14 @@ def test_ssh_fileshare_recursive( "bar", ], } - files_lines = { - path: [line + "\n" for line in lines] for (path, lines) in files_lines.items() - } + files_lines = {path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()} with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2: # Setup the directory structure. 
- for file_path, lines in files_lines.items(): + for (file_path, lines) in files_lines.items(): path = Path(tempdir1, file_path) path.parent.mkdir(parents=True, exist_ok=True) - with open(path, mode="w+t", encoding="utf-8") as temp_file: + with open(path, mode='w+t', encoding='utf-8') as temp_file: temp_file.writelines(lines) temp_file.flush() assert os.path.getsize(path) > 0 @@ -151,16 +147,15 @@ def test_ssh_fileshare_recursive( @requires_docker -def test_ssh_fileshare_download_file_dne( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService -) -> None: +def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService single file download that doesn't exist.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() canary_str = "canary" - with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: + with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: temp_file.writelines([canary_str]) temp_file.flush() temp_file.close() @@ -171,22 +166,20 @@ def test_ssh_fileshare_download_file_dne( remote_path="/tmp/file-dne.txt", local_path=temp_file.name, ) - with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: + with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == [canary_str] @requires_docker -def test_ssh_fileshare_upload_file_dne( - ssh_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - ssh_fileshare_service: SshFileShareService, -) -> None: +def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + ssh_fileshare_service: SshFileShareService) -> None: """Test the SshFileShareService single file upload that doesn't exist.""" with ssh_host_service, ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() - path = "/tmp/upload-file-src-dne.txt" + path = '/tmp/upload-file-src-dne.txt' with pytest.raises(OSError): ssh_fileshare_service.upload( params=config, diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 40a9d4ae74..4c8e5e0c66 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -27,11 +27,9 @@ @requires_docker -def test_ssh_service_remote_exec( - ssh_test_server: SshTestServerInfo, - alt_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, -) -> None: +def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, + alt_test_server: SshTestServerInfo, + ssh_host_service: SshHostService) -> None: """ Test the SshHostService remote_exec. @@ -44,9 +42,7 @@ def test_ssh_service_remote_exec( connection_id = SshClient.id_from_params(ssh_test_server.to_connect_params()) assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None - connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get( - connection_id - ) + connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(connection_id) assert connection_client is None (status, results_info) = ssh_host_service.remote_exec( @@ -61,9 +57,7 @@ def test_ssh_service_remote_exec( assert results["stdout"].strip() == SSH_TEST_SERVER_NAME # Check that the client caching is behaving as expected. 
- connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ - connection_id - ] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] assert connection is not None assert connection._username == ssh_test_server.username assert connection._host == ssh_test_server.hostname @@ -97,15 +91,13 @@ def test_ssh_service_remote_exec( }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) - assert status.is_failed() # should retain exit code from "false" + assert status.is_failed() # should retain exit code from "false" stdout = str(results["stdout"]) assert stdout.splitlines() == [ "BAR=bar", "UNUSED=", ] - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ - connection_id - ] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] assert connection._local_port == local_port # Close the connection (gracefully) @@ -122,7 +114,7 @@ def test_ssh_service_remote_exec( config=config, # Also test interacting with environment_variables. env_params={ - "FOO": "foo", + 'FOO': 'foo', }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) @@ -135,21 +127,17 @@ def test_ssh_service_remote_exec( "BAZ=", ] # Make sure it looks like we reconnected. - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ - connection_id - ] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] assert connection._local_port != local_port # Make sure the cache is cleaned up on context exit. assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0 -def check_ssh_service_reboot( - docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - graceful: bool, -) -> None: +def check_ssh_service_reboot(docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + graceful: bool) -> None: """ Check the SshHostService reboot operation. """ @@ -160,7 +148,11 @@ def check_ssh_service_reboot( with ssh_host_service: reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True) (status, results_info) = ssh_host_service.remote_exec( - script=['echo "sleeping..."', "sleep 30", 'echo "should not reach this point"'], + script=[ + 'echo "sleeping..."', + 'sleep 30', + 'echo "should not reach this point"' + ], config=reboot_test_srv_ssh_svc_conf, env_params={}, ) @@ -169,9 +161,8 @@ def check_ssh_service_reboot( time.sleep(1) # Now try to restart the server. 
- (status, reboot_results_info) = ssh_host_service.reboot( - params=reboot_test_srv_ssh_svc_conf, force=not graceful - ) + (status, reboot_results_info) = ssh_host_service.reboot(params=reboot_test_srv_ssh_svc_conf, + force=not graceful) assert status.is_pending() (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info) @@ -192,34 +183,19 @@ def check_ssh_service_reboot( time.sleep(1) # try to reconnect and see if the port changed try: - run_res = run( - "docker ps | grep mlos_bench-test- | grep reboot", - shell=True, - capture_output=True, - check=False, - ) + run_res = run("docker ps | grep mlos_bench-test- | grep reboot", shell=True, capture_output=True, check=False) print(run_res.stdout.decode()) print(run_res.stderr.decode()) - reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config( - uncached=True - ) - if ( - reboot_test_srv_ssh_svc_conf_new["ssh_port"] - != reboot_test_srv_ssh_svc_conf["ssh_port"] - ): + reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(uncached=True) + if reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"]: break except CalledProcessError as ex: _LOG.info("Failed to check port for reboot test server: %s", ex) - assert ( - reboot_test_srv_ssh_svc_conf_new["ssh_port"] - != reboot_test_srv_ssh_svc_conf["ssh_port"] - ) + assert reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"] - wait_docker_service_socket( - docker_services, - reboot_test_server.hostname, - reboot_test_srv_ssh_svc_conf_new["ssh_port"], - ) + wait_docker_service_socket(docker_services, + reboot_test_server.hostname, + reboot_test_srv_ssh_svc_conf_new["ssh_port"]) (status, results_info) = ssh_host_service.remote_exec( script=["hostname"], @@ -232,18 +208,12 @@ def check_ssh_service_reboot( @requires_docker -def test_ssh_service_reboot( - locked_docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, -) -> None: +def test_ssh_service_reboot(locked_docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService) -> None: """ Test the SshHostService reboot operation. """ # Grouped together to avoid parallel runner interactions. - check_ssh_service_reboot( - locked_docker_services, reboot_test_server, ssh_host_service, graceful=True - ) - check_ssh_service_reboot( - locked_docker_services, reboot_test_server, ssh_host_service, graceful=False - ) + check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=True) + check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=False) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index ee9f310510..7bee929fea 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -35,9 +35,7 @@ # We replaced pytest-lazy-fixture with pytest-lazy-fixtures: # https://github.com/TvoroG/pytest-lazy-fixture/issues/65 if version("pytest-lazy-fixture"): - raise UserWarning( - "pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it." - ) + raise UserWarning("pytest-lazy-fixture conflicts with pytest>=8.0.0. 
Please remove it.") except PackageNotFoundError: # OK: pytest-lazy-fixture not installed pass @@ -45,14 +43,12 @@ @requires_docker @requires_ssh -@pytest.mark.parametrize( - ["ssh_test_server_info", "server_name"], - [ - (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), - (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), - ], -) -def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, server_name: str) -> None: +@pytest.mark.parametrize(["ssh_test_server_info", "server_name"], [ + (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), + (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), +]) +def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, + server_name: str) -> None: """Check for the pytest-docker ssh test infra.""" assert ssh_test_server_info.service_name == server_name @@ -61,18 +57,17 @@ def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, server_ local_port = ssh_test_server_info.get_port() assert check_socket(ip_addr, local_port) - ssh_cmd = ( - "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " - + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " + ssh_cmd = "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " \ + + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " \ + f"-p {local_port} {ssh_test_server_info.hostname} hostname" - ) - cmd = run(ssh_cmd.split(), capture_output=True, text=True, check=True) + cmd = run(ssh_cmd.split(), + capture_output=True, + text=True, + check=True) assert cmd.stdout.strip() == server_name -@pytest.mark.filterwarnings( - "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" -) +@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") def test_ssh_service_context_handler() -> None: """ Test the SSH service context manager handling. @@ -105,23 +100,17 @@ def test_ssh_service_context_handler() -> None: with ssh_fileshare_service: assert ssh_fileshare_service._in_context assert ssh_host_service._in_context - assert ( - SshService._EVENT_LOOP_CONTEXT._event_loop_thread - is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread + assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread \ + is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread \ is ssh_fileshare_service._EVENT_LOOP_CONTEXT._event_loop_thread - ) - assert ( - SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE - is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ + is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ is ssh_fileshare_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE - ) assert not ssh_fileshare_service._in_context # And that instance should be unusable after we are outside the context. - with pytest.raises( - AssertionError - ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result="foo")) + with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result='foo')) raise ValueError(f"Future should not have been available to wait on {future.result()}") # The background thread should remain running since we have another context still open. 
@@ -129,6 +118,6 @@ def test_ssh_service_context_handler() -> None: assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None -if __name__ == "__main__": +if __name__ == '__main__': # For debugging in Windows which has issues with pytest detection in vscode. pytest.main(["-n1", "--dist=no", "-k", "test_ssh_service_background_thread"]) diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index 20320042ee..2c16df65c4 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -19,9 +19,7 @@ mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage exp_storage_with_trials = sql_storage_fixtures.exp_storage_with_trials exp_no_tunables_storage_with_trials = sql_storage_fixtures.exp_no_tunables_storage_with_trials -mixed_numerics_exp_storage_with_trials = ( - sql_storage_fixtures.mixed_numerics_exp_storage_with_trials -) +mixed_numerics_exp_storage_with_trials = sql_storage_fixtures.mixed_numerics_exp_storage_with_trials exp_data = sql_storage_fixtures.exp_data exp_no_tunables_data = sql_storage_fixtures.exp_no_tunables_data mixed_numerics_exp_data = sql_storage_fixtures.mixed_numerics_exp_data diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index 685e92f7f9..8159043be1 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -22,32 +22,23 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) assert exp.objectives == exp_storage.opt_targets -def test_exp_data_root_env_config( - exp_storage: Storage.Experiment, exp_data: ExperimentData -) -> None: +def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None: """Tests the root_env_config property of ExperimentData""" # pylint: disable=protected-access - assert exp_data.root_env_config == ( - exp_storage._root_env_config, - exp_storage._git_repo, - exp_storage._git_commit, - ) + assert exp_data.root_env_config == (exp_storage._root_env_config, exp_storage._git_repo, exp_storage._git_commit) -def test_exp_trial_data_objectives( - storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups -) -> None: +def test_exp_trial_data_objectives(storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Start a new trial and check the storage for the trial data. 
""" - trial_opt_new = exp_storage.new_trial( - tunable_groups, - config={ - "opt_target": "some-other-target", - "opt_direction": "max", - }, - ) + trial_opt_new = exp_storage.new_trial(tunable_groups, config={ + "opt_target": "some-other-target", + "opt_direction": "max", + }) assert trial_opt_new.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_new.trial_id, @@ -55,13 +46,10 @@ def test_exp_trial_data_objectives( "opt_direction": "max", } - trial_opt_old = exp_storage.new_trial( - tunable_groups, - config={ - "opt_target": "back-compat", - # "opt_direction": "max", # missing - }, - ) + trial_opt_old = exp_storage.new_trial(tunable_groups, config={ + "opt_target": "back-compat", + # "opt_direction": "max", # missing + }) assert trial_opt_old.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_old.trial_id, @@ -86,14 +74,9 @@ def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGr assert len(results_df["tunable_config_id"].unique()) == CONFIG_COUNT assert len(results_df["trial_id"].unique()) == expected_trials_count obj_target = next(iter(exp_data.objectives)) - assert ( - len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count - ) + assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count (tunable, _covariant_group) = next(iter(tunable_groups)) - assert ( - len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) - == expected_trials_count - ) + assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_trials_count def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None: @@ -133,15 +116,13 @@ def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_config_trial_groups.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. - assert [ - config_trial_group.tunable_config_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [config_trial_group.tunable_config_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT + 1)) # And the tunable_config_trial_group_id should also match the minimum trial_id. - assert [ - config_trial_group.tunable_config_trial_group_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) + assert [config_trial_group.tunable_config_trial_group_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: @@ -149,9 +130,9 @@ def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_configs.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. 
- assert [config.tunable_config_id for config in exp_data.tunable_configs.values()] == list( - range(1, CONFIG_COUNT + 1) - ) + assert [config.tunable_config_id + for config in exp_data.tunable_configs.values() + ] == list(range(1, CONFIG_COUNT + 1)) def test_exp_data_default_config_id(exp_data: ExperimentData) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index 292996db4f..d0a5edc694 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -37,9 +37,9 @@ def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_exp_trial_pending(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start a trial and check that it is pending. """ @@ -50,14 +50,14 @@ def test_exp_trial_pending( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_many( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_exp_trial_pending_many(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start THREE trials and check that both are pending. """ - config1 = tunable_groups.copy().assign({"idle": "mwait"}) - config2 = tunable_groups.copy().assign({"idle": "noidle"}) + config1 = tunable_groups.copy().assign({'idle': 'mwait'}) + config2 = tunable_groups.copy().assign({'idle': 'noidle'}) trial_ids = { exp_storage.new_trial(config1).trial_id, exp_storage.new_trial(config2).trial_id, @@ -72,9 +72,9 @@ def test_exp_trial_pending_many( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_fail( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start a trial, fail it, and and check that it is NOT pending. """ @@ -85,9 +85,9 @@ def test_exp_trial_pending_fail( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_success( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_exp_trial_success(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start a trial, finish it successfully, and and check that it is NOT pending. """ @@ -98,9 +98,9 @@ def test_exp_trial_success( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_categ( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_exp_trial_update_categ(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Update the trial with multiple metrics, some of which are categorical. 
""" @@ -108,23 +108,21 @@ def test_exp_trial_update_categ( trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"}) assert exp_storage.load() == ( [trial.trial_id], - [ - { - "idle": "halt", - "kernel_sched_latency_ns": "2000000", - "kernel_sched_migration_cost_ns": "-1", - "vmSize": "Standard_B4ms", - } - ], + [{ + 'idle': 'halt', + 'kernel_sched_latency_ns': '2000000', + 'kernel_sched_migration_cost_ns': '-1', + 'vmSize': 'Standard_B4ms' + }], [{"score": "99.9", "benchmark": "test"}], - [Status.SUCCEEDED], + [Status.SUCCEEDED] ) @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_twice( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_exp_trial_update_twice(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Update the trial status twice and receive an error. """ @@ -135,9 +133,9 @@ def test_exp_trial_update_twice( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_3( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] -) -> None: +def test_exp_trial_pending_3(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo]) -> None: """ Start THREE trials, let one succeed, another one fail and keep one not updated. Check that one is still pending another one can be loaded into the optimizer. diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index f9072a2b8d..7e346a5ccc 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -36,7 +36,7 @@ def storage() -> SqlStorage: "drivername": "sqlite", "database": ":memory:", # "database": "mlos_bench.pytest.db", - }, + } ) @@ -106,9 +106,7 @@ def mixed_numerics_exp_storage( assert not exp._in_context -def _dummy_run_exp( - exp: SqlStorage.Experiment, tunable_name: Optional[str] -) -> SqlStorage.Experiment: +def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> SqlStorage.Experiment: """ Generates data by doing a simulated run of the given experiment. """ @@ -121,30 +119,24 @@ def _dummy_run_exp( (tunable_min, tunable_max) = tunable.range tunable_range = tunable_max - tunable_min rand_seed(SEED) - opt = MockOptimizer( - tunables=exp.tunables, - config={ - "seed": SEED, - # This should be the default, so we leave it omitted for now to test the default. - # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) - # "start_with_defaults": True, - }, - ) + opt = MockOptimizer(tunables=exp.tunables, config={ + "seed": SEED, + # This should be the default, so we leave it omitted for now to test the default. 
+ # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) + # "start_with_defaults": True, + }) assert opt.start_with_defaults for config_i in range(CONFIG_COUNT): tunables = opt.suggest() for repeat_j in range(CONFIG_TRIAL_REPEAT_COUNT): - trial = exp.new_trial( - tunables=tunables.copy(), - config={ - "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(exp.opt_targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - }, - }, - ) + trial = exp.new_trial(tunables=tunables.copy(), config={ + "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(exp.opt_targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + } + }) if exp.tunables: assert trial.tunable_config_id == config_i + 1 else: @@ -155,23 +147,14 @@ def _dummy_run_exp( else: tunable_value_norm = 0 timestamp = datetime.now(UTC) - trial.update_telemetry( - status=Status.RUNNING, - timestamp=timestamp, - metrics=[ - (timestamp, "some-metric", tunable_value_norm + random() / 100), - ], - ) - trial.update( - Status.SUCCEEDED, - timestamp, - metrics={ - # Give some variance on the score. - # And some influence from the tunable value. - "score": tunable_value_norm - + random() / 100 - }, - ) + trial.update_telemetry(status=Status.RUNNING, timestamp=timestamp, metrics=[ + (timestamp, "some-metric", tunable_value_norm + random() / 100), + ]) + trial.update(Status.SUCCEEDED, timestamp, metrics={ + # Give some variance on the score. + # And some influence from the tunable value. + "score": tunable_value_norm + random() / 100 + }) return exp @@ -184,9 +167,7 @@ def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Ex @pytest.fixture -def exp_no_tunables_storage_with_trials( - exp_no_tunables_storage: SqlStorage.Experiment, -) -> SqlStorage.Experiment: +def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: """ Test fixture for Experiment using in-memory SQLite3 storage. """ @@ -195,9 +176,7 @@ def exp_no_tunables_storage_with_trials( @pytest.fixture -def mixed_numerics_exp_storage_with_trials( - mixed_numerics_exp_storage: SqlStorage.Experiment, -) -> SqlStorage.Experiment: +def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: """ Test fixture for Experiment using in-memory SQLite3 storage. """ @@ -206,9 +185,7 @@ def mixed_numerics_exp_storage_with_trials( @pytest.fixture -def exp_data( - storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment -) -> ExperimentData: +def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: """ Test fixture for ExperimentData. """ @@ -216,9 +193,7 @@ def exp_data( @pytest.fixture -def exp_no_tunables_data( - storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment -) -> ExperimentData: +def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: """ Test fixture for ExperimentData with no tunable configs. 
""" @@ -226,9 +201,7 @@ def exp_no_tunables_data( @pytest.fixture -def mixed_numerics_exp_data( - storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment -) -> ExperimentData: +def mixed_numerics_exp_data(storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: """ Test fixture for ExperimentData with mixed numerical tunable types. """ diff --git a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py index 088daca84a..ba965ed3c6 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py @@ -13,7 +13,8 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: +def test_exp_trial_pending(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Schedule a trial and check that it is pending and has the right configuration. """ @@ -30,12 +31,13 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: Tuna } -def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: +def test_exp_trial_configs(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Start multiple trials with two different configs and check that we store only two config objects in the DB. """ - config1 = tunable_groups.copy().assign({"idle": "mwait"}) + config1 = tunable_groups.copy().assign({'idle': 'mwait'}) trials1 = [ exp_storage.new_trial(config1), exp_storage.new_trial(config1), @@ -44,7 +46,7 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: Tuna assert trials1[0].tunable_config_id == trials1[1].tunable_config_id assert trials1[0].tunable_config_id == trials1[2].tunable_config_id - config2 = tunable_groups.copy().assign({"idle": "halt"}) + config2 = tunable_groups.copy().assign({'idle': 'halt'}) trials2 = [ exp_storage.new_trial(config2), exp_storage.new_trial(config2), diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index debd983cf0..04f4f18ae3 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -22,7 +22,8 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]: return set(t.trial_id for t in trials) -def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: +def test_schedule_trial(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups) -> None: """ Schedule several trials for future execution and retrieve them later at certain timestamps. 
""" @@ -43,14 +44,16 @@ def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: Tunable # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: - pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) + pending_ids = _trial_ids( + exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, } # Get trials scheduled to run within the next 1 hour: - pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) + pending_ids = _trial_ids( + exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -59,8 +62,7 @@ def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: Tunable # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) - ) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -82,8 +84,7 @@ def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: Tunable # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) - ) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) assert pending_ids == { trial_1h.trial_id, trial_2h.trial_id, @@ -91,8 +92,7 @@ def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: Tunable # Get trials scheduled to run OR running within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True) - ) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -114,9 +114,7 @@ def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: Tunable assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED] # Get only trials completed after trial_now2: - (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load( - last_trial_id=trial_now2.trial_id - ) + (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(last_trial_id=trial_now2.trial_id) assert trial_ids == [trial_1h.trial_id] assert len(trial_configs) == len(trial_scores) == 1 assert trial_status == [Status.SUCCEEDED] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index 449b564395..855c6cd861 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -31,21 +31,18 @@ def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, st """ timestamp1 = datetime.now(zone_info) timestamp2 = timestamp1 + timedelta(seconds=1) - return sorted( - [ - (timestamp1, "cpu_load", 10.1), - (timestamp1, "memory", 20), - (timestamp1, "setup", "prod"), - (timestamp2, "cpu_load", 30.1), - (timestamp2, "memory", 40), - (timestamp2, "setup", "prod"), - ] - ) + return sorted([ + (timestamp1, "cpu_load", 10.1), + (timestamp1, "memory", 20), + (timestamp1, "setup", "prod"), + (timestamp2, "cpu_load", 30.1), + (timestamp2, "memory", 40), + (timestamp2, "setup", "prod"), + ]) -def 
_telemetry_str( - data: List[Tuple[datetime, str, Any]] -) -> List[Tuple[datetime, str, Optional[str]]]: +def _telemetry_str(data: List[Tuple[datetime, str, Any]] + ) -> List[Tuple[datetime, str, Optional[str]]]: """ Convert telemetry values to strings. """ @@ -54,12 +51,10 @@ def _telemetry_str( @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry( - storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo], -) -> None: +def test_update_telemetry(storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo]) -> None: """ Make sure update_telemetry() and load_telemetry() methods work. """ @@ -78,11 +73,9 @@ def test_update_telemetry( @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry_twice( - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo], -) -> None: +def test_update_telemetry_twice(exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo]) -> None: """ Make sure update_telemetry() call is idempotent. """ diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index 251c50b241..3b57222822 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -10,9 +10,8 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_trial_data_tunable_config_data( - exp_data: ExperimentData, tunable_groups: TunableGroups -) -> None: +def test_trial_data_tunable_config_data(exp_data: ExperimentData, + tunable_groups: TunableGroups) -> None: """ Check expected return values for TunableConfigData. """ @@ -30,12 +29,12 @@ def test_trial_metadata(exp_data: ExperimentData) -> None: """ Check expected return values for TunableConfigData metadata. """ - assert exp_data.objectives == {"score": "min"} - for trial_id, trial in exp_data.trials.items(): + assert exp_data.objectives == {'score': 'min'} + for (trial_id, trial) in exp_data.trials.items(): assert trial.metadata_dict == { - "opt_target_0": "score", - "opt_direction_0": "min", - "trial_number": trial_id, + 'opt_target_0': 'score', + 'opt_direction_0': 'min', + 'trial_number': trial_id, } @@ -49,13 +48,13 @@ def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData def test_mixed_numerics_exp_trial_data( - mixed_numerics_exp_data: ExperimentData, mixed_numerics_tunable_groups: TunableGroups -) -> None: + mixed_numerics_exp_data: ExperimentData, + mixed_numerics_tunable_groups: TunableGroups) -> None: """ Tests that data type conversions are retained when loading experiment data with mixed numeric tunable types. 
""" trial = next(iter(mixed_numerics_exp_data.trials.values())) config = trial.tunable_config.config_dict - for tunable, _group in mixed_numerics_tunable_groups: + for (tunable, _group) in mixed_numerics_tunable_groups: assert isinstance(config[tunable.name], tunable.dtype) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index fd57d07635..d08b26e92d 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -16,15 +16,10 @@ def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None: trial_id = 1 trial = exp_data.trials[trial_id] tunable_config_trial_group = trial.tunable_config_trial_group - assert ( - tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id - ) + assert tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id assert tunable_config_trial_group.tunable_config_id == trial.tunable_config_id assert tunable_config_trial_group.tunable_config == trial.tunable_config - assert ( - tunable_config_trial_group - == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group - ) + assert tunable_config_trial_group == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None: @@ -54,9 +49,7 @@ def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) # And so on ... -def test_tunable_config_trial_group_results_df( - exp_data: ExperimentData, tunable_groups: TunableGroups -) -> None: +def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: """Tests the results_df property of the TunableConfigTrialGroup.""" tunable_config_id = 2 expected_group_id = 4 @@ -65,14 +58,9 @@ def test_tunable_config_trial_group_results_df( # We shouldn't have the results for the other configs, just this one. 
expected_count = CONFIG_TRIAL_REPEAT_COUNT assert len(results_df) == expected_count - assert ( - len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count - ) + assert len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count assert len(results_df[(results_df["tunable_config_id"] != tunable_config_id)]) == 0 - assert ( - len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) - == expected_count - ) + assert len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) == expected_count assert len(results_df[(results_df["tunable_config_trial_group_id"] != expected_group_id)]) == 0 assert len(results_df["trial_id"].unique()) == expected_count obj_target = next(iter(exp_data.objectives)) @@ -88,14 +76,8 @@ def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None: tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id] trials = tunable_config_trial_group.trials assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT - assert all( - trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id - for trial in trials.values() - ) - assert all( - trial.tunable_config_id == tunable_config_id - for trial in tunable_config_trial_group.trials.values() - ) - assert ( - exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] - ) + assert all(trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id + for trial in trials.values()) + assert all(trial.tunable_config_id == tunable_config_id + for trial in tunable_config_trial_group.trials.values()) + assert exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py index c3acd9d243..fa947610da 100644 --- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py +++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py @@ -24,7 +24,7 @@ ] -@pytest.mark.skipif(sys.platform == "win32", reason="TZ environment variable is a UNIXism") +@pytest.mark.skipif(sys.platform == 'win32', reason="TZ environment variable is a UNIXism") @pytest.mark.parametrize(("tz_name"), ZONE_NAMES) @pytest.mark.parametrize(("test_file"), TZ_TEST_FILES) def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: @@ -45,6 +45,4 @@ def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: if cmd.returncode != 0: print(cmd.stdout.decode()) print(cmd.stderr.decode()) - raise AssertionError( - f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'" - ) + raise AssertionError(f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'") diff --git a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py index 8329b51bd0..822547b1da 100644 --- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py +++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py @@ -119,26 +119,24 @@ def mixed_numerics_tunable_groups() -> TunableGroups: tunable_groups : TunableGroups A new TunableGroups object for testing. 
""" - tunables = TunableGroups( - { - "mix-numerics": { - "cost": 1, - "params": { - "int": { - "description": "An integer", - "type": "int", - "default": 0, - "range": [0, 100], - }, - "float": { - "description": "A float", - "type": "float", - "default": 0, - "range": [0, 1], - }, + tunables = TunableGroups({ + "mix-numerics": { + "cost": 1, + "params": { + "int": { + "description": "An integer", + "type": "int", + "default": 0, + "range": [0, 100], }, - }, - } - ) + "float": { + "description": "A float", + "type": "float", + "default": 0, + "range": [0, 1], + }, + } + }, + }) tunables.reset() return tunables diff --git a/mlos_bench/mlos_bench/tests/tunables/conftest.py b/mlos_bench/mlos_bench/tests/tunables/conftest.py index 878471b59e..95de20d9b8 100644 --- a/mlos_bench/mlos_bench/tests/tunables/conftest.py +++ b/mlos_bench/mlos_bench/tests/tunables/conftest.py @@ -25,15 +25,12 @@ def tunable_categorical() -> Tunable: tunable : Tunable An instance of a categorical Tunable. """ - return Tunable( - "vmSize", - { - "description": "Azure VM size", - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - }, - ) + return Tunable("vmSize", { + "description": "Azure VM size", + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] + }) @pytest.fixture @@ -46,16 +43,13 @@ def tunable_int() -> Tunable: tunable : Tunable An instance of an integer Tunable. """ - return Tunable( - "kernel_sched_migration_cost_ns", - { - "description": "Cost of migrating the thread to another core", - "type": "int", - "default": 40000, - "range": [0, 500000], - "special": [-1], # Special value outside of the range - }, - ) + return Tunable("kernel_sched_migration_cost_ns", { + "description": "Cost of migrating the thread to another core", + "type": "int", + "default": 40000, + "range": [0, 500000], + "special": [-1] # Special value outside of the range + }) @pytest.fixture @@ -68,12 +62,9 @@ def tunable_float() -> Tunable: tunable : Tunable An instance of a float Tunable. """ - return Tunable( - "chaos_monkey_prob", - { - "description": "Probability of spontaneous VM shutdown", - "type": "float", - "default": 0.01, - "range": [0, 1], - }, - ) + return Tunable("chaos_monkey_prob", { + "description": "Probability of spontaneous VM shutdown", + "type": "float", + "default": 0.01, + "range": [0, 1] + }) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py index e8b3e6b4cc..0e910f3761 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py @@ -38,7 +38,7 @@ def test_tunable_categorical_types() -> None: "values": ["a", "b", "c"], "default": "a", }, - }, + } } } tunable_groups = TunableGroups(tunable_params) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index c42ae21676..58bb0368b1 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -14,7 +14,6 @@ # Note: these test do *not* check the ConfigSpace conversions for those same Tunables. 
# That is checked indirectly via grid_search_optimizer_test.py - def test_tunable_int_size_props() -> None: """Test tunable int size properties""" tunable = Tunable( @@ -23,8 +22,7 @@ def test_tunable_int_size_props() -> None: "type": "int", "range": [1, 5], "default": 3, - }, - ) + }) assert tunable.span == 4 assert tunable.cardinality == 5 expected = [1, 2, 3, 4, 5] @@ -40,8 +38,7 @@ def test_tunable_float_size_props() -> None: "type": "float", "range": [1.5, 5], "default": 3, - }, - ) + }) assert tunable.span == 3.5 assert tunable.cardinality == np.inf assert tunable.quantized_values is None @@ -56,8 +53,7 @@ def test_tunable_categorical_size_props() -> None: "type": "categorical", "values": ["a", "b", "c"], "default": "a", - }, - ) + }) with pytest.raises(AssertionError): _ = tunable.span assert tunable.cardinality == 3 @@ -70,8 +66,12 @@ def test_tunable_quantized_int_size_props() -> None: """Test quantized tunable int size properties""" tunable = Tunable( name="test", - config={"type": "int", "range": [100, 1000], "default": 100, "quantization": 100}, - ) + config={ + "type": "int", + "range": [100, 1000], + "default": 100, + "quantization": 100 + }) assert tunable.span == 900 assert tunable.cardinality == 10 expected = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] @@ -82,8 +82,13 @@ def test_tunable_quantized_int_size_props() -> None: def test_tunable_quantized_float_size_props() -> None: """Test quantized tunable float size properties""" tunable = Tunable( - name="test", config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1} - ) + name="test", + config={ + "type": "float", + "range": [0, 1], + "default": 0, + "quantization": .1 + }) assert tunable.span == 1 assert tunable.cardinality == 11 expected = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py index 407998b3a4..6a91b14016 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py @@ -28,7 +28,7 @@ def test_tunable_int_name_lt(tunable_int: Tunable) -> None: Tests that the __lt__ operator works as expected. 
""" tunable_int_2 = tunable_int.copy() - tunable_int_2._name = "aaa" # pylint: disable=protected-access + tunable_int_2._name = "aaa" # pylint: disable=protected-access assert tunable_int_2 < tunable_int @@ -38,8 +38,7 @@ def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None: """ tunable_categorical_2 = tunable_categorical.copy() new_value = [ - x - for x in tunable_categorical.categories + x for x in tunable_categorical.categories if x != tunable_categorical.category and x is not None ][0] assert tunable_categorical.category is not None @@ -60,7 +59,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - }, + } ) tunable_dog = Tunable( name="same-name", @@ -68,7 +67,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": [None, "doggo"], "default": None, - }, + } ) assert tunable_dog < tunable_cat @@ -83,7 +82,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - }, + } ) tunable_int = Tunable( name="same-name", @@ -91,7 +90,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "int", "range": [1, 3], "default": 2, - }, + } ) assert tunable_cat < tunable_int @@ -102,7 +101,7 @@ def test_tunable_lt_different_object(tunable_int: Tunable) -> None: """ assert (tunable_int < "foo") is False with pytest.raises(TypeError): - assert "foo" < tunable_int # type: ignore[operator] + assert "foo" < tunable_int # type: ignore[operator] def test_tunable_group_ne_object(tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py index 980fda06a4..f2da3ba60e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py @@ -18,7 +18,7 @@ def test_tunable_name() -> None: """ with pytest.raises(ValueError): # ! 
characters are currently disallowed in tunable names - Tunable(name="test!tunable", config={"type": "float", "range": [0, 1], "default": 0}) + Tunable(name='test!tunable', config={"type": "float", "range": [0, 1], "default": 0}) def test_categorical_required_params() -> None: @@ -34,7 +34,7 @@ def test_categorical_required_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_weights() -> None: @@ -50,7 +50,7 @@ def test_categorical_weights() -> None: } """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.weights == [25, 25, 50] @@ -68,7 +68,7 @@ def test_categorical_weights_wrong_count() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_weights_wrong_values() -> None: @@ -85,7 +85,7 @@ def test_categorical_weights_wrong_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_wrong_params() -> None: @@ -102,7 +102,7 @@ def test_categorical_wrong_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_disallow_special_values() -> None: @@ -119,7 +119,7 @@ def test_categorical_disallow_special_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_categorical_tunable_disallow_repeats() -> None: @@ -127,14 +127,11 @@ def test_categorical_tunable_disallow_repeats() -> None: Disallow duplicate values in categorical tunables. """ with pytest.raises(ValueError): - Tunable( - name="test", - config={ - "type": "categorical", - "values": ["foo", "bar", "foo"], - "default": "foo", - }, - ) + Tunable(name='test', config={ + "type": "categorical", + "values": ["foo", "bar", "foo"], + "default": "foo", + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -143,14 +140,11 @@ def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeN Disallow null values as default for numerical tunables. """ with pytest.raises(ValueError): - Tunable( - name=f"test_{tunable_type}", - config={ - "type": tunable_type, - "range": [0, 10], - "default": None, - }, - ) + Tunable(name=f'test_{tunable_type}', config={ + "type": tunable_type, + "range": [0, 10], + "default": None, + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -159,14 +153,11 @@ def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeN Disallow out of range values as default for numerical tunables. """ with pytest.raises(ValueError): - Tunable( - name=f"test_{tunable_type}", - config={ - "type": tunable_type, - "range": [0, 10], - "default": 11, - }, - ) + Tunable(name=f'test_{tunable_type}', config={ + "type": tunable_type, + "range": [0, 10], + "default": 11, + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -175,15 +166,12 @@ def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> N Disallow values param for numerical tunables. 
""" with pytest.raises(ValueError): - Tunable( - name=f"test_{tunable_type}", - config={ - "type": tunable_type, - "range": [0, 10], - "values": ["foo", "bar"], - "default": 0, - }, - ) + Tunable(name=f'test_{tunable_type}', config={ + "type": tunable_type, + "range": [0, 10], + "values": ["foo", "bar"], + "default": 0, + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -200,7 +188,7 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f"test_{tunable_type}", config=config) + Tunable(name=f'test_{tunable_type}', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -217,7 +205,7 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(AssertionError): - Tunable(name=f"test_{tunable_type}", config=config) + Tunable(name=f'test_{tunable_type}', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -234,7 +222,7 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f"test_{tunable_type}", config=config) + Tunable(name=f'test_{tunable_type}', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -253,7 +241,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.special == [0] assert tunable.weights == [0.1] assert tunable.range_weight == 0.9 @@ -273,7 +261,7 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.quantization == 10 assert not tunable.is_log @@ -292,7 +280,7 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.is_log @@ -311,7 +299,7 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -331,7 +319,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) - tunable = Tunable(name="test", config=config) + tunable = Tunable(name='test', config=config) assert tunable.special == [-1, 0] assert tunable.weights == [0, 10] # Zero weights are ok assert tunable.range_weight == 90 @@ -354,7 +342,7 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -373,7 +361,7 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -392,7 
+380,7 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -410,7 +398,7 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -430,7 +418,7 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -448,7 +436,7 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test", config=config) + Tunable(name='test', config=config) def test_bad_type() -> None: @@ -464,4 +452,4 @@ def test_bad_type() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name="test_bad_type", config=config) + Tunable(name='test_bad_type', config=config) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py index e8817319ab..deffcb6a46 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py @@ -17,15 +17,14 @@ def test_categorical_distribution() -> None: Try to instantiate a categorical tunable with distribution specified. """ with pytest.raises(ValueError): - Tunable( - name="test", - config={ - "type": "categorical", - "values": ["foo", "bar", "baz"], - "distribution": {"type": "uniform"}, - "default": "foo", + Tunable(name='test', config={ + "type": "categorical", + "values": ["foo", "bar", "baz"], + "distribution": { + "type": "uniform" }, - ) + "default": "foo" + }) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -33,15 +32,14 @@ def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> N """ Create a numeric Tunable with explicit uniform distribution. """ - tunable = Tunable( - name="test", - config={ - "type": tunable_type, - "range": [0, 10], - "distribution": {"type": "uniform"}, - "default": 0, + tunable = Tunable(name="test", config={ + "type": tunable_type, + "range": [0, 10], + "distribution": { + "type": "uniform" }, - ) + "default": 0 + }) assert tunable.is_numerical assert tunable.distribution == "uniform" assert not tunable.distribution_params @@ -52,15 +50,18 @@ def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> No """ Create a numeric Tunable with explicit Gaussian distribution specified. 
""" - tunable = Tunable( - name="test", - config={ - "type": tunable_type, - "range": [0, 10], - "distribution": {"type": "normal", "params": {"mu": 0, "sigma": 1.0}}, - "default": 0, + tunable = Tunable(name="test", config={ + "type": tunable_type, + "range": [0, 10], + "distribution": { + "type": "normal", + "params": { + "mu": 0, + "sigma": 1.0 + } }, - ) + "default": 0 + }) assert tunable.distribution == "normal" assert tunable.distribution_params == {"mu": 0, "sigma": 1.0} @@ -70,15 +71,18 @@ def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None """ Create a numeric Tunable with explicit Beta distribution specified. """ - tunable = Tunable( - name="test", - config={ - "type": tunable_type, - "range": [0, 10], - "distribution": {"type": "beta", "params": {"alpha": 2, "beta": 5}}, - "default": 0, + tunable = Tunable(name="test", config={ + "type": tunable_type, + "range": [0, 10], + "distribution": { + "type": "beta", + "params": { + "alpha": 2, + "beta": 5 + } }, - ) + "default": 0 + }) assert tunable.distribution == "beta" assert tunable.distribution_params == {"alpha": 2, "beta": 5} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py index d9b209cf4f..c6fb5670f0 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py @@ -10,9 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_tunable_group_indexing( - tunable_groups: TunableGroups, tunable_categorical: Tunable -) -> None: +def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categorical: Tunable) -> None: """ Check that various types of indexing work for the tunable group. """ diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py index 186de4acfa..55a485e951 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py @@ -14,4 +14,4 @@ def test_tunable_group_subgroup(tunable_groups: TunableGroups) -> None: Check that the subgroup() method returns only a selection of tunable parameters. 
""" tunables = tunable_groups.subgroup(["provision"]) - assert tunables.get_param_values() == {"vmSize": "Standard_B4ms"} + assert tunables.get_param_values() == {'vmSize': 'Standard_B4ms'} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index 0dfbdd2acd..73e3a12caa 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -36,39 +36,37 @@ @pytest.mark.parametrize("param_type", ["int", "float"]) -@pytest.mark.parametrize( - "distr_name,distr_params", - [ - ("normal", {"mu": 0.0, "sigma": 1.0}), - ("beta", {"alpha": 2, "beta": 5}), - ("uniform", {}), - ], -) -def test_convert_numerical_distributions( - param_type: str, distr_name: DistributionName, distr_params: dict -) -> None: +@pytest.mark.parametrize("distr_name,distr_params", [ + ("normal", {"mu": 0.0, "sigma": 1.0}), + ("beta", {"alpha": 2, "beta": 5}), + ("uniform", {}), +]) +def test_convert_numerical_distributions(param_type: str, + distr_name: DistributionName, + distr_params: dict) -> None: """ Convert a numerical Tunable with explicit distribution to ConfigSpace. """ tunable_name = "x" - tunable_groups = TunableGroups( - { - "tunable_group": { - "cost": 1, - "params": { - tunable_name: { - "type": param_type, - "range": [0, 100], - "special": [-1, 0], - "special_weights": [0.1, 0.2], - "range_weight": 0.7, - "distribution": {"type": distr_name, "params": distr_params}, - "default": 0, - } - }, + tunable_groups = TunableGroups({ + "tunable_group": { + "cost": 1, + "params": { + tunable_name: { + "type": param_type, + "range": [0, 100], + "special": [-1, 0], + "special_weights": [0.1, 0.2], + "range_weight": 0.7, + "distribution": { + "type": distr_name, + "params": distr_params + }, + "default": 0 + } } } - ) + }) (tunable, _group) = tunable_groups.get_tunable(tunable_name) assert tunable.distribution == distr_name @@ -84,5 +82,5 @@ def test_convert_numerical_distributions( cs_param = space[tunable_name] assert isinstance(cs_param, _CS_HYPERPARAMETER[param_type, distr_name]) - for key, val in distr_params.items(): + for (key, val) in distr_params.items(): assert getattr(cs_param, key) == val diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 39bd41e282..78e91fd25e 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -38,23 +38,17 @@ def configuration_space() -> ConfigurationSpace: configuration_space : ConfigurationSpace A new ConfigurationSpace object for testing. 
""" - (kernel_sched_migration_cost_ns_special, kernel_sched_migration_cost_ns_type) = ( - special_param_names("kernel_sched_migration_cost_ns") - ) - - spaces = ConfigurationSpace( - space={ - "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - "idle": ["halt", "mwait", "noidle"], - "kernel_sched_migration_cost_ns": (0, 500000), - kernel_sched_migration_cost_ns_special: [-1, 0], - kernel_sched_migration_cost_ns_type: [ - TunableValueKind.SPECIAL, - TunableValueKind.RANGE, - ], - "kernel_sched_latency_ns": (0, 1000000000), - } - ) + (kernel_sched_migration_cost_ns_special, + kernel_sched_migration_cost_ns_type) = special_param_names("kernel_sched_migration_cost_ns") + + spaces = ConfigurationSpace(space={ + "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + "idle": ["halt", "mwait", "noidle"], + "kernel_sched_migration_cost_ns": (0, 500000), + kernel_sched_migration_cost_ns_special: [-1, 0], + kernel_sched_migration_cost_ns_type: [TunableValueKind.SPECIAL, TunableValueKind.RANGE], + "kernel_sched_latency_ns": (0, 1000000000), + }) # NOTE: FLAML requires distribution to be uniform spaces["vmSize"].default_value = "Standard_B4ms" @@ -66,25 +60,18 @@ def configuration_space() -> ConfigurationSpace: spaces[kernel_sched_migration_cost_ns_type].probabilities = (0.5, 0.5) spaces["kernel_sched_latency_ns"].default_value = 2000000 - spaces.add_condition( - EqualsCondition( - spaces[kernel_sched_migration_cost_ns_special], - spaces[kernel_sched_migration_cost_ns_type], - TunableValueKind.SPECIAL, - ) - ) - spaces.add_condition( - EqualsCondition( - spaces["kernel_sched_migration_cost_ns"], - spaces[kernel_sched_migration_cost_ns_type], - TunableValueKind.RANGE, - ) - ) + spaces.add_condition(EqualsCondition( + spaces[kernel_sched_migration_cost_ns_special], + spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.SPECIAL)) + spaces.add_condition(EqualsCondition( + spaces["kernel_sched_migration_cost_ns"], + spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.RANGE)) return spaces -def _cmp_tunable_hyperparameter_categorical(tunable: Tunable, space: ConfigurationSpace) -> None: +def _cmp_tunable_hyperparameter_categorical( + tunable: Tunable, space: ConfigurationSpace) -> None: """ Check if categorical Tunable and ConfigSpace Hyperparameter actually match. """ @@ -94,7 +81,8 @@ def _cmp_tunable_hyperparameter_categorical(tunable: Tunable, space: Configurati assert param.default_value == tunable.value -def _cmp_tunable_hyperparameter_numerical(tunable: Tunable, space: ConfigurationSpace) -> None: +def _cmp_tunable_hyperparameter_numerical( + tunable: Tunable, space: ConfigurationSpace) -> None: """ Check if integer Tunable and ConfigSpace Hyperparameter actually match. """ @@ -142,13 +130,12 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> Non Make sure that the corresponding Tunable and Hyperparameter objects match. """ space = tunable_groups_to_configspace(tunable_groups) - for tunable, _group in tunable_groups: + for (tunable, _group) in tunable_groups: _CMP_FUNC[tunable.type](tunable, space) def test_tunable_groups_to_configspace( - tunable_groups: TunableGroups, configuration_space: ConfigurationSpace -) -> None: + tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None: """ Check the conversion of the entire TunableGroups collection to a single ConfigurationSpace object. 
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index 2f7790602f..cbccd6bfe1 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -19,14 +19,12 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: that don't exist in the TunableGroups object. """ with pytest.raises(KeyError): - tunable_groups.assign( - { - "vmSize": "Standard_B2ms", - "idle": "mwait", - "UnknownParam_1": 1, - "UnknownParam_2": "invalid-value", - } - ) + tunable_groups.assign({ + "vmSize": "Standard_B2ms", + "idle": "mwait", + "UnknownParam_1": 1, + "UnknownParam_2": "invalid-value" + }) def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None: @@ -108,7 +106,7 @@ def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None: Check str to int coercion. """ tunable_int.value = "10" - assert tunable_int.value == 10 # type: ignore[comparison-overlap] + assert tunable_int.value == 10 # type: ignore[comparison-overlap] assert not tunable_int.is_special @@ -117,7 +115,7 @@ def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None: Check str to float coercion. """ tunable_float.value = "0.5" - assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] + assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] assert not tunable_float.is_special @@ -151,12 +149,12 @@ def test_tunable_assign_null_to_categorical() -> None: } """ config = json.loads(json_config) - categorical_tunable = Tunable(name="categorical_test", config=config) + categorical_tunable = Tunable(name='categorical_test', config=config) assert categorical_tunable assert categorical_tunable.category == "foo" categorical_tunable.value = None assert categorical_tunable.value is None - assert categorical_tunable.value != "None" + assert categorical_tunable.value != 'None' assert categorical_tunable.category is None @@ -167,7 +165,7 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_int.value = None with pytest.raises((TypeError, AssertionError)): - tunable_int.numerical_value = None # type: ignore[assignment] + tunable_int.numerical_value = None # type: ignore[assignment] def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: @@ -177,7 +175,7 @@ def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_float.value = None with pytest.raises((TypeError, AssertionError)): - tunable_float.numerical_value = None # type: ignore[assignment] + tunable_float.numerical_value = None # type: ignore[assignment] def test_tunable_assign_special(tunable_int: Tunable) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py index cb41f7f7d8..672b16ab73 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py @@ -17,44 +17,42 @@ def test_tunable_groups_str(tunable_groups: TunableGroups) -> None: tunables within each covariant group. 
""" # Same as `tunable_groups` (defined in the `conftest.py` file), but in different order: - tunables_other = TunableGroups( - { - "kernel": { - "cost": 1, - "params": { - "kernel_sched_latency_ns": { - "type": "int", - "default": 2000000, - "range": [0, 1000000000], - }, - "kernel_sched_migration_cost_ns": { - "type": "int", - "default": -1, - "range": [0, 500000], - "special": [-1], - }, + tunables_other = TunableGroups({ + "kernel": { + "cost": 1, + "params": { + "kernel_sched_latency_ns": { + "type": "int", + "default": 2000000, + "range": [0, 1000000000] }, - }, - "boot": { - "cost": 300, - "params": { - "idle": { - "type": "categorical", - "default": "halt", - "values": ["halt", "mwait", "noidle"], - } - }, - }, - "provision": { - "cost": 1000, - "params": { - "vmSize": { - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - } - }, - }, - } - ) + "kernel_sched_migration_cost_ns": { + "type": "int", + "default": -1, + "range": [0, 500000], + "special": [-1] + } + } + }, + "boot": { + "cost": 300, + "params": { + "idle": { + "type": "categorical", + "default": "halt", + "values": ["halt", "mwait", "noidle"] + } + } + }, + "provision": { + "cost": 1000, + "params": { + "vmSize": { + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] + } + } + }, + }) assert str(tunable_groups) == str(tunables_other) diff --git a/mlos_bench/mlos_bench/tunables/__init__.py b/mlos_bench/mlos_bench/tunables/__init__.py index 3433f4a735..4191f37d89 100644 --- a/mlos_bench/mlos_bench/tunables/__init__.py +++ b/mlos_bench/mlos_bench/tunables/__init__.py @@ -10,7 +10,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups __all__ = [ - "Tunable", - "TunableValue", - "TunableGroups", + 'Tunable', + 'TunableValue', + 'TunableGroups', ] diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index 797510a087..fee4fd5841 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -93,12 +93,10 @@ def __eq__(self, other: object) -> bool: return False # TODO: May need to provide logic to relax the equality check on the # tunables (e.g. "compatible" vs. "equal"). 
- return ( - self._name == other._name - and self._cost == other._cost - and self._is_updated == other._is_updated - and self._tunables == other._tunables - ) + return (self._name == other._name and + self._cost == other._cost and + self._is_updated == other._is_updated and + self._tunables == other._tunables) def equals_defaults(self, other: "CovariantTunableGroup") -> bool: """ @@ -236,11 +234,7 @@ def __contains__(self, tunable: Union[str, Tunable]) -> bool: def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: return self.get_tunable(tunable).value - def __setitem__( - self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] - ) -> TunableValue: - value: TunableValue = ( - tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value - ) + def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: + value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value self._is_updated |= self.get_tunable(tunable).update(value) return value diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index b2a465c71a..1ebd70dfa4 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -107,7 +107,7 @@ def __init__(self, name: str, config: TunableDict): config : dict Python dict that represents a Tunable (e.g., deserialized from JSON) """ - if not isinstance(name, str) or "!" in name: # TODO: Use a regex here and in JSON schema + if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema raise ValueError(f"Invalid name of the tunable: {name}") self._name = name self._type: TunableValueTypeName = config["type"] # required @@ -202,16 +202,10 @@ def _sanity_check_numerical(self) -> None: raise ValueError(f"Number of quantization points is <= 1: {self}") if self.dtype == float: if not isinstance(self._quantization, (float, int)): - raise ValueError( - f"Quantization of a float param should be a float or int: {self}" - ) + raise ValueError(f"Quantization of a float param should be a float or int: {self}") if self._quantization <= 0: raise ValueError(f"Number of quantization points is <= 0: {self}") - if self._distribution is not None and self._distribution not in { - "uniform", - "normal", - "beta", - }: + if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}: raise ValueError(f"Invalid distribution: {self}") if self._distribution_params and self._distribution is None: raise ValueError(f"Must specify the distribution: {self}") @@ -236,9 +230,7 @@ def __repr__(self) -> str: """ # TODO? Add weights, specials, quantization, distribution? 
if self.is_categorical: - return ( - f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}" - ) + return f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}" return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}" def __eq__(self, other: object) -> bool: @@ -259,12 +251,12 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Tunable): return False return bool( - self._name == other._name - and self._type == other._type - and self._current_value == other._current_value + self._name == other._name and + self._type == other._type and + self._current_value == other._current_value ) - def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements + def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements """ Compare the two Tunable objects. We mostly need this to create a canonical list of tunable objects when hashing a TunableGroup. @@ -344,21 +336,18 @@ def value(self, value: TunableValue) -> TunableValue: assert value is not None coerced_value = self.dtype(value) except Exception: - _LOG.error( - "Impossible conversion: %s %s <- %s %s", self._type, self._name, type(value), value - ) + _LOG.error("Impossible conversion: %s %s <- %s %s", + self._type, self._name, type(value), value) raise if self._type == "int" and isinstance(value, float) and value != coerced_value: - _LOG.error( - "Loss of precision: %s %s <- %s %s", self._type, self._name, type(value), value - ) + _LOG.error("Loss of precision: %s %s <- %s %s", + self._type, self._name, type(value), value) raise ValueError(f"Loss of precision: {self._name}={value}") if not self.is_valid(coerced_value): - _LOG.error( - "Invalid assignment: %s %s <- %s %s", self._type, self._name, type(value), value - ) + _LOG.error("Invalid assignment: %s %s <- %s %s", + self._type, self._name, type(value), value) raise ValueError(f"Invalid value for the tunable: {self._name}={value}") self._current_value = coerced_value @@ -414,10 +403,10 @@ def in_range(self, value: Union[int, float, str, None]) -> bool: Return False if the tunable or value is categorical or None. """ return ( - isinstance(value, (float, int)) - and self.is_numerical - and self._range is not None - and bool(self._range[0] <= value <= self._range[1]) + isinstance(value, (float, int)) and + self.is_numerical and + self._range is not None and + bool(self._range[0] <= value <= self._range[1]) ) @property @@ -637,12 +626,10 @@ def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]: # Be sure to return python types instead of numpy types. 
cardinality = self.cardinality assert isinstance(cardinality, int) - return ( - float(x) - for x in np.linspace( - start=num_range[0], stop=num_range[1], num=cardinality, endpoint=True - ) - ) + return (float(x) for x in np.linspace(start=num_range[0], + stop=num_range[1], + num=cardinality, + endpoint=True)) assert self.type == "int", f"Unhandled tunable type: {self}" return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1)) diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index 8fbaee878c..0bd58c8269 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -30,11 +30,9 @@ def __init__(self, config: Optional[dict] = None): if config is None: config = {} ConfigSchema.TUNABLE_PARAMS.validate(config) - self._index: Dict[str, CovariantTunableGroup] = ( - {} - ) # Index (Tunable id -> CovariantTunableGroup) + self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup) self._tunable_groups: Dict[str, CovariantTunableGroup] = {} - for name, group_config in config.items(): + for (name, group_config) in config.items(): self._add_group(CovariantTunableGroup(name, group_config)) def __bool__(self) -> bool: @@ -83,15 +81,11 @@ def _add_group(self, group: CovariantTunableGroup) -> None: ---------- group : CovariantTunableGroup """ - assert ( - group.name not in self._tunable_groups - ), f"Duplicate covariant tunable group name {group.name} in {self}" + assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}" self._tunable_groups[group.name] = group for tunable in group.get_tunables(): if tunable.name in self._index: - raise ValueError( - f"Duplicate Tunable {tunable.name} from group {group.name} in {self}" - ) + raise ValueError(f"Duplicate Tunable {tunable.name} from group {group.name} in {self}") self._index[tunable.name] = group def merge(self, tunables: "TunableGroups") -> "TunableGroups": @@ -125,10 +119,8 @@ def merge(self, tunables: "TunableGroups") -> "TunableGroups": # Check that there's no overlap in the tunables. # But allow for differing current values. if not self._tunable_groups[group.name].equals_defaults(group): - raise ValueError( - f"Overlapping covariant tunable group name {group.name} " - + "in {self._tunable_groups[group.name]} and {tunables}" - ) + raise ValueError(f"Overlapping covariant tunable group name {group.name} " + + "in {self._tunable_groups[group.name]} and {tunables}") return self def __repr__(self) -> str: @@ -140,15 +132,10 @@ def __repr__(self) -> str: string : str A human-readable version of the TunableGroups. 
""" - return ( - "{ " - + ", ".join( - f"{group.name}::{tunable}" - for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) - for tunable in sorted(group._tunables.values()) - ) - + " }" - ) + return "{ " + ", ".join( + f"{group.name}::{tunable}" + for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) + for tunable in sorted(group._tunables.values())) + " }" def __contains__(self, tunable: Union[str, Tunable]) -> bool: """ @@ -164,17 +151,13 @@ def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: name: str = tunable.name if isinstance(tunable, Tunable) else tunable return self._index[name][name] - def __setitem__( - self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] - ) -> TunableValue: + def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: """ Update the current value of a single tunable parameter. """ # Use double index to make sure we set the is_updated flag of the group name: str = tunable.name if isinstance(tunable, Tunable) else tunable - value: TunableValue = ( - tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value - ) + value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value self._index[name][name] = value return self._index[name][name] @@ -249,11 +232,8 @@ def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": tunables._add_group(self._tunable_groups[name]) return tunables - def get_param_values( - self, - group_names: Optional[Iterable[str]] = None, - into_params: Optional[Dict[str, TunableValue]] = None, - ) -> Dict[str, TunableValue]: + def get_param_values(self, group_names: Optional[Iterable[str]] = None, + into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: """ Get the current values of the tunables that belong to the specified covariance groups. @@ -292,10 +272,8 @@ def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: is_updated : bool True if any of the specified tunable groups has been updated, False otherwise. """ - return any( - self._tunable_groups[name].is_updated() - for name in (group_names or self.get_covariant_group_names()) - ) + return any(self._tunable_groups[name].is_updated() + for name in (group_names or self.get_covariant_group_names())) def is_defaults(self) -> bool: """ @@ -321,7 +299,7 @@ def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "Tuna self : TunableGroups Self-reference for chaining. """ - for name in group_names or self.get_covariant_group_names(): + for name in (group_names or self.get_covariant_group_names()): self._tunable_groups[name].restore_defaults() return self @@ -339,7 +317,7 @@ def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": self : TunableGroups Self-reference for chaining. 
""" - for name in group_names or self.get_covariant_group_names(): + for name in (group_names or self.get_covariant_group_names()): self._tunable_groups[name].reset_is_updated() return self diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 619e712497..531988be97 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -71,9 +71,8 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> return dest -def merge_parameters( - *, dest: dict, source: Optional[dict] = None, required_keys: Optional[Iterable[str]] = None -) -> dict: +def merge_parameters(*, dest: dict, source: Optional[dict] = None, + required_keys: Optional[Iterable[str]] = None) -> dict: """ Merge the source config dict into the destination config. Pick from the source configs *ONLY* the keys that are already present @@ -133,9 +132,8 @@ def path_join(*args: str, abs_path: bool = False) -> str: return os.path.normpath(path).replace("\\", "/") -def prepare_class_load( - config: dict, global_config: Optional[Dict[str, Any]] = None -) -> Tuple[str, Dict[str, Any]]: +def prepare_class_load(config: dict, + global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. @@ -157,9 +155,8 @@ def prepare_class_load( merge_parameters(dest=class_config, source=global_config) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug( - "Instantiating: %s with config:\n%s", class_name, json.dumps(class_config, indent=2) - ) + _LOG.debug("Instantiating: %s with config:\n%s", + class_name, json.dumps(class_config, indent=2)) return (class_name, class_config) @@ -190,9 +187,8 @@ def get_class_from_name(class_name: str) -> type: # FIXME: Technically, this should return a type "class_name" derived from "base_class". -def instantiate_from_config( - base_class: Type[BaseTypeVar], class_name: str, *args: Any, **kwargs: Any -) -> BaseTypeVar: +def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str, + *args: Any, **kwargs: Any) -> BaseTypeVar: """ Factory method for a new class instantiated from config. 
@@ -242,8 +238,7 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s if missing_params: raise ValueError( "The following parameters must be provided in the configuration" - + f" or as command line arguments: {missing_params}" - ) + + f" or as command line arguments: {missing_params}") def get_git_info(path: str = __file__) -> Tuple[str, str, str]: @@ -262,14 +257,11 @@ def get_git_info(path: str = __file__) -> Tuple[str, str, str]: """ dirname = os.path.dirname(path) git_repo = subprocess.check_output( - ["git", "-C", dirname, "remote", "get-url", "origin"], text=True - ).strip() + ["git", "-C", dirname, "remote", "get-url", "origin"], text=True).strip() git_commit = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "HEAD"], text=True - ).strip() + ["git", "-C", dirname, "rev-parse", "HEAD"], text=True).strip() git_root = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True - ).strip() + ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True).strip() _LOG.debug("Current git branch: %s %s", git_repo, git_commit) rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root)) return (git_repo, git_commit, rel_path.replace("\\", "/")) @@ -363,9 +355,7 @@ def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> raise ValueError(f"Invalid origin: {origin}") -def utcify_nullable_timestamp( - timestamp: Optional[datetime], *, origin: Literal["utc", "local"] -) -> Optional[datetime]: +def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]: """ A nullable version of utcify_timestamp. """ @@ -377,9 +367,7 @@ def utcify_nullable_timestamp( _MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) -def datetime_parser( - datetime_col: pandas.Series, *, origin: Literal["utc", "local"] -) -> pandas.Series: +def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series: """ Attempt to convert a pandas column to a datetime format. @@ -413,7 +401,7 @@ def datetime_parser( new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo) assert new_datetime_col.dt.tz is not None # And convert it to UTC. - new_datetime_col = new_datetime_col.dt.tz_convert("UTC") + new_datetime_col = new_datetime_col.dt.tz_convert('UTC') if new_datetime_col.isna().any(): raise ValueError(f"Invalid date format in the data: {datetime_col}") if new_datetime_col.le(_MIN_TS).any(): diff --git a/mlos_bench/mlos_bench/version.py b/mlos_bench/mlos_bench/version.py index f8acae8c02..96d3d2b6bf 100644 --- a/mlos_bench/mlos_bench/version.py +++ b/mlos_bench/mlos_bench/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. 
-VERSION = "0.5.1" +VERSION = '0.5.1' if __name__ == "__main__": print(VERSION) diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index fc29bfbcbb..27d844c35b 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -21,16 +21,15 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns["VERSION"] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns['VERSION'] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - - version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) + version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -48,68 +47,62 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, "README.md") + readme_path = os.path.join(pkg_dir, 'README.md') if not os.path.isfile(readme_path): return { - "long_description": "missing", + 'long_description': 'missing', } - jsonc_re = re.compile(r"```jsonc") - link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") - with open(readme_path, mode="r", encoding="utf-8") as readme_fh: + jsonc_re = re.compile(r'```jsonc') + link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') + with open(readme_path, mode='r', encoding='utf-8') as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r"```json", line) for line in lines] + lines = [jsonc_re.sub(r'```json', line) for line in lines] return { - "long_description": "".join(lines), - "long_description_content_type": "text/markdown", + 'long_description': ''.join(lines), + 'long_description_content_type': 'text/markdown', } -extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass +extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass # Additional tools for extra functionality. - "azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"], - "ssh": ["asyncssh"], - "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"], - "storage-sql-mysql": ["sqlalchemy", "mysql-connector-python"], - "storage-sql-postgres": ["sqlalchemy", "psycopg2"], - "storage-sql-sqlite": [ - "sqlalchemy" - ], # sqlite3 comes with python, so we don't need to install it. + 'azure': ['azure-storage-file-share', 'azure-identity', 'azure-keyvault'], + 'ssh': ['asyncssh'], + 'storage-sql-duckdb': ['sqlalchemy', 'duckdb_engine'], + 'storage-sql-mysql': ['sqlalchemy', 'mysql-connector-python'], + 'storage-sql-postgres': ['sqlalchemy', 'psycopg2'], + 'storage-sql-sqlite': ['sqlalchemy'], # sqlite3 comes with python, so we don't need to install it. # Transitive extra_requires from mlos-core. - "flaml": ["flaml[blendsearch]"], - "smac": ["smac"], + 'flaml': ['flaml[blendsearch]'], + 'smac': ['smac'], } # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. 
-extra_requires["full"] = list(set(chain(*extra_requires.values()))) +extra_requires['full'] = list(set(chain(*extra_requires.values()))) -extra_requires["full-tests"] = extra_requires["full"] + [ - "pytest", - "pytest-forked", - "pytest-xdist", - "pytest-cov", - "pytest-local-badge", - "pytest-lazy-fixtures", - "pytest-docker", - "fasteners", +extra_requires['full-tests'] = extra_requires['full'] + [ + 'pytest', + 'pytest-forked', + 'pytest-xdist', + 'pytest-cov', + 'pytest-local-badge', + 'pytest-lazy-fixtures', + 'pytest-docker', + 'fasteners', ] setup( version=VERSION, install_requires=[ - "mlos-core==" + VERSION, - "requests", - "json5", - "jsonschema>=4.18.0", - "referencing>=0.29.1", + 'mlos-core==' + VERSION, + 'requests', + 'json5', + 'jsonschema>=4.18.0', 'referencing>=0.29.1', 'importlib_resources;python_version<"3.10"', - ] - + extra_requires[ - "storage-sql-sqlite" - ], # NOTE: For now sqlite is a fallback storage backend, so we always install it. + ] + extra_requires['storage-sql-sqlite'], # NOTE: For now sqlite is a fallback storage backend, so we always install it. extras_require=extra_requires, - **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_bench"), + **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_bench'), ) diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index b3e248e407..086002af62 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -18,12 +18,12 @@ from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType __all__ = [ - "SpaceAdapterType", - "OptimizerFactory", - "BaseOptimizer", - "RandomOptimizer", - "FlamlOptimizer", - "SmacOptimizer", + 'SpaceAdapterType', + 'OptimizerFactory', + 'BaseOptimizer', + 'RandomOptimizer', + 'FlamlOptimizer', + 'SmacOptimizer', ] @@ -45,7 +45,7 @@ class OptimizerType(Enum): # ConcreteOptimizer = TypeVar('ConcreteOptimizer', *[member.value for member in OptimizerType]) # To address this, we add a test for complete coverage of the enum. ConcreteOptimizer = TypeVar( - "ConcreteOptimizer", + 'ConcreteOptimizer', RandomOptimizer, FlamlOptimizer, SmacOptimizer, @@ -60,15 +60,13 @@ class OptimizerFactory: # pylint: disable=too-few-public-methods @staticmethod - def create( - *, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, - optimizer_kwargs: Optional[dict] = None, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None, - ) -> ConcreteOptimizer: # type: ignore[type-var] + def create(*, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, + optimizer_kwargs: Optional[dict] = None, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None) -> ConcreteOptimizer: # type: ignore[type-var] """ Create a new optimizer instance, given the parameter space, optimizer type, and potential optimizer options. 
@@ -109,7 +107,7 @@ def create( parameter_space=parameter_space, optimization_targets=optimization_targets, space_adapter=space_adapter, - **optimizer_kwargs, + **optimizer_kwargs ) return optimizer diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py index d4f59dfa52..5f32219988 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py @@ -12,6 +12,6 @@ from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer __all__ = [ - "BaseBayesianOptimizer", - "SmacOptimizer", + 'BaseBayesianOptimizer', + 'SmacOptimizer', ] diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 9d3bcabcb2..76ff0d9b3a 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -19,9 +19,8 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): """Abstract base class defining the interface for Bayesian optimization.""" @abstractmethod - def surrogate_predict( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def surrogate_predict(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None) -> npt.NDArray: """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). Parameters @@ -32,12 +31,11 @@ def surrogate_predict( context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def acquisition_function( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def acquisition_function(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None) -> npt.NDArray: """Invokes the acquisition function from this Bayesian optimizer for the given configuration. Parameters @@ -48,4 +46,4 @@ def acquisition_function( context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 5784a42f12..9d8d2a0347 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -29,22 +29,19 @@ class SmacOptimizer(BaseBayesianOptimizer): Wrapper class for SMAC based Bayesian optimization. 
""" - def __init__( - self, - *, # pylint: disable=too-many-locals,too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - seed: Optional[int] = 0, - run_name: Optional[str] = None, - output_directory: Optional[str] = None, - max_trials: int = 100, - n_random_init: Optional[int] = None, - max_ratio: Optional[float] = None, - use_default_config: bool = False, - n_random_probability: float = 0.1, - ): + def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + seed: Optional[int] = 0, + run_name: Optional[str] = None, + output_directory: Optional[str] = None, + max_trials: int = 100, + n_random_init: Optional[int] = None, + max_ratio: Optional[float] = None, + use_default_config: bool = False, + n_random_probability: float = 0.1): """ Instantiate a new SMAC optimizer wrapper. @@ -127,9 +124,7 @@ def __init__( if output_directory is None: # pylint: disable=consider-using-with try: - self._temp_output_directory = TemporaryDirectory( - ignore_cleanup_errors=True - ) # Argument added in Python 3.10 + self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True) # Argument added in Python 3.10 except TypeError: self._temp_output_directory = TemporaryDirectory() output_directory = self._temp_output_directory.name @@ -151,12 +146,8 @@ def __init__( seed=seed or -1, # if -1, SMAC will generate a random seed internally n_workers=1, # Use a single thread for evaluating trials ) - intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier( - scenario, max_config_calls=1 - ) - config_selector: ConfigSelector = Optimizer_Smac.get_config_selector( - scenario, retrain_after=1 - ) + intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier(scenario, max_config_calls=1) + config_selector: ConfigSelector = Optimizer_Smac.get_config_selector(scenario, retrain_after=1) # TODO: When bulk registering prior configs to rewarm the optimizer, # there is a way to inform SMAC's initial design that we have @@ -167,27 +158,27 @@ def __init__( # See Also: #488 initial_design_args: Dict[str, Union[list, int, float, Scenario]] = { - "scenario": scenario, + 'scenario': scenario, # Workaround a bug in SMAC that sets a default arg to a mutable # value that can cause issues when multiple optimizers are # instantiated with the use_default_config option within the same # process that use different ConfigSpaces so that the second # receives the default config from both as an additional config. - "additional_configs": [], + 'additional_configs': [] } if n_random_init is not None: - initial_design_args["n_configs"] = n_random_init + initial_design_args['n_configs'] = n_random_init if n_random_init > 0.25 * max_trials and max_ratio is None: warning( - "Number of random initial configs (%d) is " - + "greater than 25%% of max_trials (%d). " - + "Consider setting max_ratio to avoid SMAC overriding n_random_init.", + 'Number of random initial configs (%d) is ' + + 'greater than 25%% of max_trials (%d). 
' + + 'Consider setting max_ratio to avoid SMAC overriding n_random_init.', n_random_init, max_trials, ) if max_ratio is not None: assert isinstance(max_ratio, float) and 0.0 <= max_ratio <= 1.0 - initial_design_args["max_ratio"] = max_ratio + initial_design_args['max_ratio'] = max_ratio # Use the default InitialDesign from SMAC. # (currently SBOL instead of LatinHypercube due to better uniformity @@ -199,9 +190,7 @@ def __init__( # design when generated a random_design for itself via the # get_random_design static method when random_design is None. assert isinstance(n_random_probability, float) and n_random_probability >= 0 - random_design = ProbabilityRandomDesign( - probability=n_random_probability, seed=scenario.seed - ) + random_design = ProbabilityRandomDesign(probability=n_random_probability, seed=scenario.seed) self.base_optimizer = Optimizer_Smac( scenario, @@ -211,8 +200,7 @@ def __init__( random_design=random_design, config_selector=config_selector, multi_objective_algorithm=Optimizer_Smac.get_multi_objective_algorithm( - scenario, objective_weights=self._objective_weights - ), + scenario, objective_weights=self._objective_weights), overwrite=True, logging_level=False, # Use the existing logger ) @@ -253,16 +241,10 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None """ # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946 - raise RuntimeError("This function should never be called.") - - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + raise RuntimeError('This function should never be called.') + + def _register(self, *, configs: pd.DataFrame, + scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. Parameters @@ -289,22 +271,17 @@ def _register( warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) # Register each trial (one-by-one) - for config, (_i, score) in zip( - self._to_configspace_configs(configs=configs), scores.iterrows() - ): + for (config, (_i, score)) in zip(self._to_configspace_configs(configs=configs), scores.iterrows()): # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance info: TrialInfo = self.trial_info_map.get( - config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) - ) + config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed)) value = TrialValue(cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS) self.base_optimizer.tell(info, value, save=False) # Save optimizer once we register all configs self.base_optimizer.optimizer.save() - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. 
Parameters @@ -333,23 +310,15 @@ def _suggest( self.optimizer_parameter_space.check_configuration(trial.config) assert trial.config.config_space == self.optimizer_parameter_space self.trial_info_map[trial.config] = trial - config_df = pd.DataFrame( - [trial.config], columns=list(self.optimizer_parameter_space.keys()) - ) + config_df = pd.DataFrame([trial.config], columns=list(self.optimizer_parameter_space.keys())) return config_df, None - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None) -> None: raise NotImplementedError() - def surrogate_predict( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: from smac.utils.configspace import ( convert_configurations_to_array, # pylint: disable=import-outside-toplevel ) @@ -362,23 +331,16 @@ def surrogate_predict( # pylint: disable=protected-access if len(self._observations) <= self.base_optimizer._initial_design._n_configs: raise RuntimeError( - "Surrogate model can make predictions *only* after all initial points have been evaluated " - + f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}" - ) + 'Surrogate model can make predictions *only* after all initial points have been evaluated ' + + f'{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}') if self.base_optimizer._config_selector._model is None: - raise RuntimeError("Surrogate model is not yet trained") + raise RuntimeError('Surrogate model is not yet trained') - config_array: npt.NDArray = convert_configurations_to_array( - self._to_configspace_configs(configs=configs) - ) + config_array: npt.NDArray = convert_configurations_to_array(self._to_configspace_configs(configs=configs)) mean_predictions, _ = self.base_optimizer._config_selector._model.predict(config_array) - return mean_predictions.reshape( - -1, - ) + return mean_predictions.reshape(-1,) - def acquisition_function( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None - ) -> npt.NDArray: + def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) if self._space_adapter: @@ -386,15 +348,13 @@ def acquisition_function( # pylint: disable=protected-access if self.base_optimizer._config_selector._acquisition_function is None: - raise RuntimeError("Acquisition function is not yet initialized") + raise RuntimeError('Acquisition function is not yet initialized') cs_configs: list = self._to_configspace_configs(configs=configs) - return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape( - -1, - ) + return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape(-1,) def cleanup(self) -> None: - if hasattr(self, "_temp_output_directory") and self._temp_output_directory is not None: + if hasattr(self, '_temp_output_directory') and self._temp_output_directory is not None: self._temp_output_directory.cleanup() self._temp_output_directory = None @@ -413,5 +373,5 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. 
""" return [ ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict()) - for (_, config) in configs.astype("O").iterrows() + for (_, config) in configs.astype('O').iterrows() ] diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 2df19b8eb2..273c89eecc 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -33,16 +33,13 @@ class FlamlOptimizer(BaseOptimizer): # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. _METRIC_NAME = "FLAML_score" - def __init__( - self, - *, # pylint: disable=too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - low_cost_partial_config: Optional[dict] = None, - seed: Optional[int] = None, - ): + def __init__(self, *, # pylint: disable=too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + low_cost_partial_config: Optional[dict] = None, + seed: Optional[int] = None): """ Create an MLOS wrapper for FLAML. @@ -85,22 +82,14 @@ def __init__( configspace_to_flaml_space, ) - self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space( - self.optimizer_parameter_space - ) + self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(self.optimizer_parameter_space) self.low_cost_partial_config = low_cost_partial_config self.evaluated_samples: Dict[ConfigSpace.Configuration, EvaluatedSample] = {} self._suggested_config: Optional[dict] - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. Parameters @@ -122,10 +111,9 @@ def _register( if metadata is not None: warn(f"Not Implemented: Ignoring metadata {list(metadata.columns)}", UserWarning) - for (_, config), (_, score) in zip(configs.astype("O").iterrows(), scores.iterrows()): + for (_, config), (_, score) in zip(configs.astype('O').iterrows(), scores.iterrows()): cs_config: ConfigSpace.Configuration = ConfigSpace.Configuration( - self.optimizer_parameter_space, values=config.to_dict() - ) + self.optimizer_parameter_space, values=config.to_dict()) if cs_config in self.evaluated_samples: warn(f"Configuration {config} was already registered", UserWarning) self.evaluated_samples[cs_config] = EvaluatedSample( @@ -133,9 +121,7 @@ def _register( score=float(np.average(score.astype(float), weights=self._objective_weights)), ) - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Sampled at random using ConfigSpace. 
@@ -158,13 +144,8 @@ def _suggest( config: dict = self._get_next_config() return pd.DataFrame(config, index=[0]), None - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: raise NotImplementedError() def _target_function(self, config: dict) -> Union[dict, None]: @@ -219,14 +200,16 @@ def _get_next_config(self) -> dict: dict(normalize_config(self.optimizer_parameter_space, conf)) for conf in self.evaluated_samples ] - evaluated_rewards = [s.score for s in self.evaluated_samples.values()] + evaluated_rewards = [ + s.score for s in self.evaluated_samples.values() + ] # Warm start FLAML optimizer self._suggested_config = None tune.run( self._target_function, config=self.flaml_parameter_space, - mode="min", + mode='min', metric=self._METRIC_NAME, points_to_evaluate=points_to_evaluate, evaluated_rewards=evaluated_rewards, @@ -235,6 +218,6 @@ def _get_next_config(self) -> dict: verbose=0, ) if self._suggested_config is None: - raise RuntimeError("FLAML did not produce a suggestion") + raise RuntimeError('FLAML did not produce a suggestion') return self._suggested_config # type: ignore[unreachable] diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index f96bce7075..4ab9db5a2f 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -24,14 +24,11 @@ class BaseOptimizer(metaclass=ABCMeta): Optimizer abstract base class defining the basic interface. """ - def __init__( - self, - *, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - ): + def __init__(self, *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None): """ Create a new instance of the base optimizer. @@ -47,9 +44,8 @@ def __init__( The space adapter class to employ for parameter space transformations. """ self.parameter_space: ConfigSpace.ConfigurationSpace = parameter_space - self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = ( + self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = \ parameter_space if space_adapter is None else space_adapter.target_parameter_space - ) if space_adapter is not None and space_adapter.orig_parameter_space != parameter_space: raise ValueError("Given parameter space differs from the one given to space adapter") @@ -72,14 +68,8 @@ def space_adapter(self) -> Optional[BaseSpaceAdapter]: """Get the space adapter instance (if any).""" return self._space_adapter - def register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Wrapper method, which employs the space adapter (if any), before registering the configs and scores. Parameters @@ -97,37 +87,29 @@ def register( """ # Do some input validation. 
assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(scores.columns) == set( - self._optimization_targets - ), "Mismatched optimization targets." - assert self._has_context is None or self._has_context ^ ( - context is None - ), "Context must always be added or never be added." - assert len(configs) == len(scores), "Mismatched number of configs and scores." + assert set(scores.columns) == set(self._optimization_targets), \ + "Mismatched optimization targets." + assert self._has_context is None or self._has_context ^ (context is None), \ + "Context must always be added or never be added." + assert len(configs) == len(scores), \ + "Mismatched number of configs and scores." if context is not None: - assert len(configs) == len(context), "Mismatched number of configs and context." - assert configs.shape[1] == len( - self.parameter_space.values() - ), "Mismatched configuration shape." + assert len(configs) == len(context), \ + "Mismatched number of configs and context." + assert configs.shape[1] == len(self.parameter_space.values()), \ + "Mismatched configuration shape." self._observations.append((configs, scores, context)) self._has_context = context is not None if self._space_adapter: configs = self._space_adapter.inverse_transform(configs) - assert configs.shape[1] == len( - self.optimizer_parameter_space.values() - ), "Mismatched configuration shape after inverse transform." + assert configs.shape[1] == len(self.optimizer_parameter_space.values()), \ + "Mismatched configuration shape after inverse transform." return self._register(configs=configs, scores=scores, context=context) @abstractmethod - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. Parameters @@ -140,11 +122,10 @@ def _register( context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover - def suggest( - self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def suggest(self, *, context: Optional[pd.DataFrame] = None, + defaults: bool = False) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. @@ -168,21 +149,18 @@ def suggest( configuration = self.space_adapter.inverse_transform(configuration) else: configuration, metadata = self._suggest(context=context) - assert len(configuration) == 1, "Suggest must return a single configuration." - assert set(configuration.columns).issubset( - set(self.optimizer_parameter_space) - ), "Optimizer suggested a configuration that does not match the expected parameter space." + assert len(configuration) == 1, \ + "Suggest must return a single configuration." + assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), \ + "Optimizer suggested a configuration that does not match the expected parameter space." 
if self._space_adapter: configuration = self._space_adapter.transform(configuration) - assert set(configuration.columns).issubset( - set(self.parameter_space) - ), "Space adapter produced a configuration that does not match the expected parameter space." + assert set(configuration.columns).issubset(set(self.parameter_space)), \ + "Space adapter produced a configuration that does not match the expected parameter space." return configuration, metadata @abstractmethod - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Parameters @@ -198,16 +176,12 @@ def _suggest( metadata : Optional[pd.DataFrame] The metadata associated with the given configuration used for evaluations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs as "pending". That is it say, it has been suggested by the optimizer, and an experiment trial has been started. This can be useful for executing multiple trials in parallel, retry logic, etc. @@ -221,7 +195,7 @@ def register_pending( metadata : Optional[pd.DataFrame] Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ @@ -236,17 +210,11 @@ def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.Data raise ValueError("No observations registered yet.") configs = pd.concat([config for config, _, _ in self._observations]).reset_index(drop=True) scores = pd.concat([score for _, score, _ in self._observations]).reset_index(drop=True) - contexts = pd.concat( - [ - pd.DataFrame() if context is None else context - for _, _, context in self._observations - ] - ).reset_index(drop=True) + contexts = pd.concat([pd.DataFrame() if context is None else context + for _, _, context in self._observations]).reset_index(drop=True) return (configs, scores, contexts if len(contexts.columns) > 0 else None) - def get_best_observations( - self, *, n_max: int = 1 - ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: + def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Get the N best observations so far as a triplet of DataFrames (config, score, context). Default is N=1. The columns are ordered in ASCENDING order of the optimization targets. 
@@ -266,7 +234,8 @@ def get_best_observations( raise ValueError("No observations registered yet.") (configs, scores, contexts) = self.get_observations() idx = scores.nsmallest(n_max, columns=self._optimization_targets, keep="first").index - return (configs.loc[idx], scores.loc[idx], None if contexts is None else contexts.loc[idx]) + return (configs.loc[idx], scores.loc[idx], + None if contexts is None else contexts.loc[idx]) def cleanup(self) -> None: """ @@ -284,7 +253,7 @@ def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame: j = 0 for param in self.optimizer_parameter_space.values(): if isinstance(param, ConfigSpace.CategoricalHyperparameter): - for offset, val in enumerate(param.choices): + for (offset, val) in enumerate(param.choices): if config[i][j + offset] == 1: df_dict[param.name].append(val) break diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index bf6f85ff88..0af785ef20 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -24,14 +24,8 @@ class RandomOptimizer(BaseOptimizer): The parameter space to optimize. """ - def _register( - self, - *, - configs: pd.DataFrame, - scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: """Registers the given configs and scores. Doesn't do anything on the RandomOptimizer except storing configs for logging. @@ -56,9 +50,7 @@ def _register( warn(f"Not Implemented: Ignoring context {list(metadata.columns)}", UserWarning) # should we pop them from self.pending_observations? - def _suggest( - self, *, context: Optional[pd.DataFrame] = None - ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """Suggests a new configuration. Sampled at random using ConfigSpace. @@ -79,17 +71,9 @@ def _suggest( if context is not None: # not sure how that works here? warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) - return ( - pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), - None, - ) - - def register_pending( - self, - *, - configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None, - ) -> None: + return pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), None + + def register_pending(self, *, configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: raise NotImplementedError() # self._pending_observations.append((configs, context)) diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 73e7f37dc3..2e2f585590 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -15,8 +15,8 @@ from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter __all__ = [ - "IdentityAdapter", - "LlamaTuneAdapter", + 'IdentityAdapter', + 'LlamaTuneAdapter', ] @@ -35,7 +35,7 @@ class SpaceAdapterType(Enum): # ConcreteSpaceAdapter = TypeVar('ConcreteSpaceAdapter', *[member.value for member in SpaceAdapterType]) # To address this, we add a test for complete coverage of the enum. 
ConcreteSpaceAdapter = TypeVar( - "ConcreteSpaceAdapter", + 'ConcreteSpaceAdapter', IdentityAdapter, LlamaTuneAdapter, ) @@ -47,12 +47,10 @@ class SpaceAdapterFactory: # pylint: disable=too-few-public-methods @staticmethod - def create( - *, - parameter_space: ConfigSpace.ConfigurationSpace, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None, - ) -> ConcreteSpaceAdapter: # type: ignore[type-var] + def create(*, + parameter_space: ConfigSpace.ConfigurationSpace, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None) -> ConcreteSpaceAdapter: # type: ignore[type-var] """ Create a new space adapter instance, given the parameter space and potential space adapter options. @@ -77,7 +75,8 @@ def create( space_adapter_kwargs = {} space_adapter: ConcreteSpaceAdapter = space_adapter_type.value( - orig_parameter_space=parameter_space, **space_adapter_kwargs + orig_parameter_space=parameter_space, + **space_adapter_kwargs ) return space_adapter diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index 58d07763f6..6c3a86fc8a 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -46,7 +46,7 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: """ Target parameter space that is fed to the underlying optimizer. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @@ -64,7 +64,7 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: Pandas dataframe with a single row, containing the translated configuration. Column names are the parameter names of the original parameter space. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: @@ -84,4 +84,4 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: Dataframe of the translated configurations / parameters. The columns are the parameter names of the target parameter space and the rows are the configurations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index b8abdedfeb..4d3a925cbc 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -19,7 +19,7 @@ from mlos_core.util import normalize_config -class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes +class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes """ Implementation of LlamaTune, a set of parameter space transformation techniques, aimed at improving the sample-efficiency of the underlying optimizer. 
@@ -28,21 +28,18 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-a DEFAULT_NUM_LOW_DIMS = 16 """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" - DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2 + DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = .2 """Default percentage of bias for each special parameter value""" DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 """Default number of (max) unique values of each parameter, when space discretization is used""" - def __init__( - self, - *, - orig_parameter_space: ConfigSpace.ConfigurationSpace, - num_low_dims: int = DEFAULT_NUM_LOW_DIMS, - special_param_values: Optional[dict] = None, - max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, - use_approximate_reverse_mapping: bool = False, - ): + def __init__(self, *, + orig_parameter_space: ConfigSpace.ConfigurationSpace, + num_low_dims: int = DEFAULT_NUM_LOW_DIMS, + special_param_values: Optional[dict] = None, + max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, + use_approximate_reverse_mapping: bool = False): """ Create a space adapter that employs LlamaTune's techniques. @@ -61,9 +58,7 @@ def __init__( super().__init__(orig_parameter_space=orig_parameter_space) if num_low_dims >= len(orig_parameter_space): - raise ValueError( - "Number of target config space dimensions should be less than those of original config space." - ) + raise ValueError("Number of target config space dimensions should be less than those of original config space.") # Validate input special param values dict special_param_values = special_param_values or {} @@ -95,10 +90,9 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: target_configurations = [] - for _, config in configurations.astype("O").iterrows(): + for (_, config) in configurations.astype('O').iterrows(): configuration = ConfigSpace.Configuration( - self.orig_parameter_space, values=config.to_dict() - ) + self.orig_parameter_space, values=config.to_dict()) target_config = self._suggested_configs.get(configuration, None) # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection @@ -110,15 +104,12 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # Default configuration should always be registerable. pass elif not self._use_approximate_reverse_mapping: - raise ValueError( - f"{repr(configuration)}\n" - "The above configuration was not suggested by the optimizer. " - "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " - "previously by the optimizer can be registered." - ) + raise ValueError(f"{repr(configuration)}\n" "The above configuration was not suggested by the optimizer. " + "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " + "previously by the optimizer can be registered.") # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix. 
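(A minimal sketch of the pseudo-inverse reverse mapping idea described in the comment above; the dense random matrix below is only a stand-in for the sparse HeSBO projection, and all names and shapes are illustrative rather than taken from the patch.)

import numpy as np
from scipy.linalg import pinv

rng = np.random.default_rng(seed=42)
orig_space_num_dims, target_space_num_dims = 8, 2

# Stand-in projection matrix: maps low-dim (target) points to high-dim (original) points.
proj_matrix = rng.standard_normal((orig_space_num_dims, target_space_num_dims))
pinv_matrix = pinv(proj_matrix)  # shape: (target_space_num_dims, orig_space_num_dims)

low_dim_point = rng.uniform(-1, 1, size=target_space_num_dims)
high_dim_point = proj_matrix.dot(low_dim_point)

# For a full-column-rank matrix, the pseudo-inverse recovers the low-dimensional point.
recovered_point = pinv_matrix.dot(high_dim_point)
assert np.allclose(recovered_point, low_dim_point)
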
- if getattr(self, "_pinv_matrix", None) is None: + if getattr(self, '_pinv_matrix', None) is None: self._try_generate_approx_inverse_mapping() # Replace NaNs with zeros for inactive hyperparameters @@ -127,27 +118,19 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # NOTE: applying special value biasing is not possible vector = self._config_scaler.inverse_transform([config_vector])[0] target_config_vector = self._pinv_matrix.dot(vector) - target_config = ConfigSpace.Configuration( - self.target_parameter_space, vector=target_config_vector - ) + target_config = ConfigSpace.Configuration(self.target_parameter_space, vector=target_config_vector) target_configurations.append(target_config) - return pd.DataFrame( - target_configurations, columns=list(self.target_parameter_space.keys()) - ) + return pd.DataFrame(target_configurations, columns=list(self.target_parameter_space.keys())) def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: if len(configuration) != 1: - raise ValueError( - "Configuration dataframe must contain exactly 1 row. " - f"Found {len(configuration)} rows." - ) + raise ValueError("Configuration dataframe must contain exactly 1 row. " + f"Found {len(configuration)} rows.") target_values_dict = configuration.iloc[0].to_dict() - target_configuration = ConfigSpace.Configuration( - self.target_parameter_space, values=target_values_dict - ) + target_configuration = ConfigSpace.Configuration(self.target_parameter_space, values=target_values_dict) orig_values_dict = self._transform(target_values_dict) orig_configuration = normalize_config(self.orig_parameter_space, orig_values_dict) @@ -155,13 +138,9 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: # Add to inverse dictionary -- needed for registering the performance later self._suggested_configs[orig_configuration] = target_configuration - return pd.DataFrame( - [list(orig_configuration.values())], columns=list(orig_configuration.keys()) - ) + return pd.DataFrame([list(orig_configuration.values())], columns=list(orig_configuration.keys())) - def _construct_low_dim_space( - self, num_low_dims: int, max_unique_values_per_param: Optional[int] - ) -> None: + def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_param: Optional[int]) -> None: """Constructs the low-dimensional parameter (potentially discretized) search space. Parameters @@ -177,7 +156,7 @@ def _construct_low_dim_space( q_scaler = None if max_unique_values_per_param is None: hyperparameters = [ - ConfigSpace.UniformFloatHyperparameter(name=f"dim_{idx}", lower=-1, upper=1) + ConfigSpace.UniformFloatHyperparameter(name=f'dim_{idx}', lower=-1, upper=1) for idx in range(num_low_dims) ] else: @@ -185,9 +164,7 @@ def _construct_low_dim_space( # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters. # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection. 
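(A small illustration of the [-1, 1] scaling mentioned in the comment above, assuming a scikit-learn MinMaxScaler as the quantization scaler; the fit vectors and sample values are illustrative and not copied from the patch.)

import numpy as np
from sklearn.preprocessing import MinMaxScaler

num_low_dims = 2
max_unique_values_per_param = 10000

# Fit the scaler on the integer range [1, max_unique_values_per_param] for each low dimension.
q_scaler = MinMaxScaler(feature_range=(-1, 1))
q_scaler.fit(np.array([
    [1] * num_low_dims,
    [max_unique_values_per_param] * num_low_dims,
]))

# Integer suggestions are mapped into the [-1, 1] range expected by the HeSBO projection.
print(q_scaler.transform(np.array([[1, 10000]])))    # -> [[-1.  1.]]
print(q_scaler.transform(np.array([[5000, 5000]])))  # -> approximately [[0.  0.]]
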
hyperparameters = [ - ConfigSpace.UniformIntegerHyperparameter( - name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param - ) + ConfigSpace.UniformIntegerHyperparameter(name=f'dim_{idx}', lower=1, upper=max_unique_values_per_param) for idx in range(num_low_dims) ] @@ -201,9 +178,7 @@ def _construct_low_dim_space( # Construct low-dimensional parameter search space config_space = ConfigSpace.ConfigurationSpace(name=self.orig_parameter_space.name) - config_space.random = ( - self._random_state - ) # use same random state as in original parameter space + config_space.random = self._random_state # use same random state as in original parameter space config_space.add_hyperparameters(hyperparameters) self._target_config_space = config_space @@ -241,10 +216,10 @@ def _transform(self, configuration: dict) -> dict: # Clip value to force it to fall in [0, 1] # NOTE: HeSBO projection ensures that theoretically but due to # floating point ops nuances this is not always guaranteed - value = max(0.0, min(1.0, norm_value)) # pylint: disable=redefined-loop-name + value = max(0., min(1., norm_value)) # pylint: disable=redefined-loop-name if isinstance(param, ConfigSpace.CategoricalHyperparameter): - index = int(value * len(param.choices)) # truncate integer part + index = int(value * len(param.choices)) # truncate integer part index = max(0, min(len(param.choices) - 1, index)) # NOTE: potential rounding here would be unfair to first & last values orig_value = param.choices[index] @@ -252,20 +227,16 @@ def _transform(self, configuration: dict) -> dict: if param.name in self._special_param_values_dict: value = self._special_param_value_scaler(param, value) - orig_value = param._transform(value) # pylint: disable=protected-access + orig_value = param._transform(value) # pylint: disable=protected-access orig_value = max(param.lower, min(param.upper, orig_value)) else: - raise NotImplementedError( - "Only Categorical, Integer, and Float hyperparameters are currently supported." - ) + raise NotImplementedError("Only Categorical, Integer, and Float hyperparameters are currently supported.") original_config[param.name] = orig_value return original_config - def _special_param_value_scaler( - self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float - ) -> float: + def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float) -> float: """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. Parameters @@ -284,7 +255,7 @@ def _special_param_value_scaler( special_values_list = self._special_param_values_dict[param.name] # Check if input value corresponds to some special value - perc_sum = 0.0 + perc_sum = 0. 
ret: float for special_value, biasing_perc in special_values_list: perc_sum += biasing_perc @@ -293,9 +264,8 @@ def _special_param_value_scaler( return ret # Scale input value uniformly to non-special values - ret = param._inverse_transform( # pylint: disable=protected-access - param._transform_scalar((input_value - perc_sum) / (1 - perc_sum)) - ) # pylint: disable=protected-access + ret = param._inverse_transform( # pylint: disable=protected-access + param._transform_scalar((input_value - perc_sum) / (1 - perc_sum))) # pylint: disable=protected-access return ret # pylint: disable=too-complex,too-many-branches @@ -324,10 +294,8 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non hyperparameter = self.orig_parameter_space[param] if not isinstance(hyperparameter, ConfigSpace.UniformIntegerHyperparameter): - raise NotImplementedError( - error_prefix + f"Parameter '{param}' is not supported. " - "Only Integer Hyperparameters are currently supported." - ) + raise NotImplementedError(error_prefix + f"Parameter '{param}' is not supported. " + "Only Integer Hyperparameters are currently supported.") if isinstance(value, int): # User specifies a single special value -- default biasing percentage is used @@ -338,57 +306,34 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non elif isinstance(value, list) and value: if all(isinstance(t, int) for t in value): # User specifies list of special values - tuple_list = [ - (v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value - ] - elif all( - isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value - ): + tuple_list = [(v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value] + elif all(isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value): # User specifies list of tuples; each tuple defines the special value and the biasing percentage tuple_list = value else: - raise ValueError( - error_prefix + f"Invalid format in value list for parameter '{param}'. " - f"Special value list should contain either integers, or (special value, biasing %) tuples." - ) + raise ValueError(error_prefix + f"Invalid format in value list for parameter '{param}'. " + f"Special value list should contain either integers, or (special value, biasing %) tuples.") else: - raise ValueError( - error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " - "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples." - ) + raise ValueError(error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " + "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples.") # Are user-specified special values valid? if not all(hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list): - raise ValueError( - error_prefix - + f"One (or more) special values are outside of parameter '{param}' value domain." - ) + raise ValueError(error_prefix + f"One (or more) special values are outside of parameter '{param}' value domain.") # Are user-provided special values unique? if len(set(v for v, _ in tuple_list)) != len(tuple_list): - raise ValueError( - error_prefix - + f"One (or more) special values are defined more than once for parameter '{param}'." - ) + raise ValueError(error_prefix + f"One (or more) special values are defined more than once for parameter '{param}'.") # Are biasing percentages valid? 
if not all(0 < perc < 1 for _, perc in tuple_list): - raise ValueError( - error_prefix - + f"One (or more) biasing percentages for parameter '{param}' are invalid: " - "i.e., fall outside (0, 1) range." - ) + raise ValueError(error_prefix + f"One (or more) biasing percentages for parameter '{param}' are invalid: " + "i.e., fall outside (0, 1) range.") total_percentage = sum(perc for _, perc in tuple_list) - if total_percentage >= 1.0: - raise ValueError( - error_prefix - + f"Total special values percentage for parameter '{param}' surpass 100%." - ) + if total_percentage >= 1.: + raise ValueError(error_prefix + f"Total special values percentage for parameter '{param}' surpass 100%.") # ... and reasonable? if total_percentage >= 0.5: - warn( - f"Total special values percentage for parameter '{param}' exceeds 50%.", - UserWarning, - ) + warn(f"Total special values percentage for parameter '{param}' exceeds 50%.", UserWarning) sanitized_dict[param] = tuple_list @@ -410,12 +355,9 @@ def _try_generate_approx_inverse_mapping(self) -> None: pinv, ) - warn( - "Trying to register a configuration that was not previously suggested by the optimizer. " - + "This inverse configuration transformation is typically not supported. " - + "However, we will try to register this configuration using an *experimental* method.", - UserWarning, - ) + warn("Trying to register a configuration that was not previously suggested by the optimizer. " + + "This inverse configuration transformation is typically not supported. " + + "However, we will try to register this configuration using an *experimental* method.", UserWarning) orig_space_num_dims = len(list(self.orig_parameter_space.values())) target_space_num_dims = len(list(self.target_parameter_space.values())) @@ -429,7 +371,5 @@ def _try_generate_approx_inverse_mapping(self) -> None: try: self._pinv_matrix = pinv(proj_matrix) except LinAlgError as err: - raise RuntimeError( - f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}" - ) from err + raise RuntimeError(f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}") from err assert self._pinv_matrix.shape == (target_space_num_dims, orig_space_num_dims) diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index 1b9e61ad91..d6918f9891 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -27,9 +27,7 @@ FlamlSpace: TypeAlias = Dict[str, flaml.tune.sample.Domain] -def configspace_to_flaml_space( - config_space: ConfigSpace.ConfigurationSpace, -) -> Dict[str, FlamlDomain]: +def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> Dict[str, FlamlDomain]: """Converts a ConfigSpace.ConfigurationSpace to dict. 
Parameters @@ -52,19 +50,13 @@ def configspace_to_flaml_space( def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain: if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter): # FIXME: upper isn't included in the range - return flaml_numeric_type[(type(parameter), parameter.log)]( - parameter.lower, parameter.upper - ) + return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper) elif isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter): - return flaml_numeric_type[(type(parameter), parameter.log)]( - parameter.lower, parameter.upper + 1 - ) + return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper + 1) elif isinstance(parameter, ConfigSpace.CategoricalHyperparameter): if len(np.unique(parameter.probabilities)) > 1: - raise ValueError( - "FLAML doesn't support categorical parameters with non-uniform probabilities." - ) - return flaml.tune.choice(parameter.choices) # TODO: set order? + raise ValueError("FLAML doesn't support categorical parameters with non-uniform probabilities.") + return flaml.tune.choice(parameter.choices) # TODO: set order? raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.") return {param.name: _one_parameter_convert(param) for param in config_space.values()} diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index 81d5151b20..a8ad146205 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -21,7 +21,7 @@ from typing_extensions import TypeAlias -T = TypeVar("T") +T = TypeVar('T') def get_all_submodules(pkg: TypeAlias) -> List[str]: @@ -30,9 +30,7 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]: Useful for dynamically enumerating subclasses. """ submodules = [] - for _, submodule_name, _ in walk_packages( - pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None - ): + for _, submodule_name, _ in walk_packages(pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None): submodules.append(submodule_name) return submodules @@ -43,8 +41,7 @@ def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]: Useful for dynamically enumerating expected test cases. 
""" return set(cls.__subclasses__()).union( - s for c in cls.__subclasses__() for s in _get_all_subclasses(c) - ) + s for c in cls.__subclasses__() for s in _get_all_subclasses(c)) def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]: @@ -60,11 +57,5 @@ def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> pkg = import_module(pkg_name) submodules = get_all_submodules(pkg) assert submodules - return sorted( - [ - subclass - for subclass in _get_all_subclasses(cls) - if not getattr(subclass, "__abstractmethods__", None) - ], - key=lambda c: (c.__module__, c.__name__), - ) + return sorted([subclass for subclass in _get_all_subclasses(cls) if not getattr(subclass, "__abstractmethods__", None)], + key=lambda c: (c.__module__, c.__name__)) diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index 775afa2455..c7a94dfcc4 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -17,27 +17,24 @@ @pytest.mark.filterwarnings("error:Not Implemented") -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_context_not_implemented_warning( - configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict], -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_context_not_implemented_warning(configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict]) -> None: """ Make sure we raise warnings for the functionality that has not been implemented yet. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=['score'], + **kwargs ) suggestion, _metadata = optimizer.suggest() - scores = pd.DataFrame({"score": [1]}) + scores = pd.DataFrame({'score': [1]}) context = pd.DataFrame([["something"]]) with pytest.raises(UserWarning): diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py index 504c91eac7..39231bec5c 100644 --- a/mlos_core/mlos_core/tests/optimizers/conftest.py +++ b/mlos_core/mlos_core/tests/optimizers/conftest.py @@ -18,9 +18,9 @@ def configuration_space() -> CS.ConfigurationSpace: # Start defining a ConfigurationSpace for the Optimizer to search. space = CS.ConfigurationSpace(seed=1234) # Add a continuous input dimension between 0 and 1. - space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) + space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) # Add a categorical hyperparameter with 3 possible values. - space.add_hyperparameter(CS.CategoricalHyperparameter(name="y", choices=["a", "b", "c"])) + space.add_hyperparameter(CS.CategoricalHyperparameter(name='y', choices=["a", "b", "c"])) # Add a discrete input dimension between 0 and 10. 
- space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10)) + space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='z', lower=0, upper=10)) return space diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index 7fe793a824..725d92fbe9 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -23,13 +23,11 @@ def data_frame() -> pd.DataFrame: Toy data frame corresponding to the `configuration_space` hyperparameters. The columns are deliberately *not* in alphabetic order. """ - return pd.DataFrame( - { - "y": ["a", "b", "c"], - "x": [0.1, 0.2, 0.3], - "z": [1, 5, 8], - } - ) + return pd.DataFrame({ + 'y': ['a', 'b', 'c'], + 'x': [0.1, 0.2, 0.3], + 'z': [1, 5, 8], + }) @pytest.fixture @@ -38,13 +36,11 @@ def one_hot_data_frame() -> npt.NDArray: One-hot encoding of the `data_frame` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array( - [ - [0.1, 1.0, 0.0, 0.0, 1.0], - [0.2, 0.0, 1.0, 0.0, 5.0], - [0.3, 0.0, 0.0, 1.0, 8.0], - ] - ) + return np.array([ + [0.1, 1.0, 0.0, 0.0, 1.0], + [0.2, 0.0, 1.0, 0.0, 5.0], + [0.3, 0.0, 0.0, 1.0, 8.0], + ]) @pytest.fixture @@ -53,13 +49,11 @@ def series() -> pd.Series: Toy series corresponding to the `configuration_space` hyperparameters. The columns are deliberately *not* in alphabetic order. """ - return pd.Series( - { - "y": "b", - "x": 0.4, - "z": 3, - } - ) + return pd.Series({ + 'y': 'b', + 'x': 0.4, + 'z': 3, + }) @pytest.fixture @@ -68,11 +62,9 @@ def one_hot_series() -> npt.NDArray: One-hot encoding of the `series` above. The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array( - [ - [0.4, 0.0, 1.0, 0.0, 3], - ] - ) + return np.array([ + [0.4, 0.0, 1.0, 0.0, 3], + ]) @pytest.fixture @@ -82,40 +74,39 @@ def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: """ return SmacOptimizer( parameter_space=configuration_space, - optimization_targets=["score"], + optimization_targets=['score'], ) -def test_to_1hot_data_frame( - optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray -) -> None: +def test_to_1hot_data_frame(optimizer: BaseOptimizer, + data_frame: pd.DataFrame, + one_hot_data_frame: npt.NDArray) -> None: """ Toy problem to test one-hot encoding of dataframe. """ assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame) -def test_to_1hot_series( - optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray -) -> None: +def test_to_1hot_series(optimizer: BaseOptimizer, + series: pd.Series, one_hot_series: npt.NDArray) -> None: """ Toy problem to test one-hot encoding of series. """ assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series) -def test_from_1hot_data_frame( - optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray -) -> None: +def test_from_1hot_data_frame(optimizer: BaseOptimizer, + data_frame: pd.DataFrame, + one_hot_data_frame: npt.NDArray) -> None: """ Toy problem to test one-hot decoding of dataframe. 
""" assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict() -def test_from_1hot_series( - optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray -) -> None: +def test_from_1hot_series(optimizer: BaseOptimizer, + series: pd.Series, + one_hot_series: npt.NDArray) -> None: """ Toy problem to test one-hot decoding of series. """ @@ -144,9 +135,8 @@ def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> No assert (series_round_trip.z == series.z).all() -def test_round_trip_reverse_data_frame( - optimizer: BaseOptimizer, one_hot_data_frame: npt.NDArray -) -> None: +def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer, + one_hot_data_frame: npt.NDArray) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ @@ -154,7 +144,8 @@ def test_round_trip_reverse_data_frame( assert round_trip == pytest.approx(one_hot_data_frame) -def test_round_trip_reverse_series(optimizer: BaseOptimizer, one_hot_series: npt.NDArray) -> None: +def test_round_trip_reverse_series(optimizer: BaseOptimizer, + one_hot_series: npt.NDArray) -> None: """ Round-trip test for one-hot-decoding and then encoding of a numpy array. """ diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index 870943c346..0b9d624a7a 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -20,15 +20,10 @@ _LOG = logging.getLogger(__name__) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_multi_target_opt_wrong_weights( - optimizer_class: Type[BaseOptimizer], kwargs: dict -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None: """ Make sure that the optimizer raises an error if the number of objective weights does not match the number of optimization targets. @@ -36,29 +31,23 @@ def test_multi_target_opt_wrong_weights( with pytest.raises(ValueError): optimizer_class( parameter_space=CS.ConfigurationSpace(seed=SEED), - optimization_targets=["main_score", "other_score"], + optimization_targets=['main_score', 'other_score'], objective_weights=[1], - **kwargs, + **kwargs ) -@pytest.mark.parametrize( - ("objective_weights"), - [ - [2, 1], - [0.5, 0.5], - None, - ], -) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_multi_target_opt( - objective_weights: Optional[List[float]], optimizer_class: Type[BaseOptimizer], kwargs: dict -) -> None: +@pytest.mark.parametrize(('objective_weights'), [ + [2, 1], + [0.5, 0.5], + None, +]) +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_multi_target_opt(objective_weights: Optional[List[float]], + optimizer_class: Type[BaseOptimizer], + kwargs: dict) -> None: """ Toy multi-target optimization problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
@@ -67,21 +56,21 @@ def test_multi_target_opt( def objective(point: pd.DataFrame) -> pd.DataFrame: # mix of hyperparameters, optimal is to select the highest possible - return pd.DataFrame( - { - "main_score": point.x + point.y, - "other_score": point.x**2 + point.y**2, - } - ) + return pd.DataFrame({ + "main_score": point.x + point.y, + "other_score": point.x ** 2 + point.y ** 2, + }) input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0)) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) optimizer = optimizer_class( parameter_space=input_space, - optimization_targets=["main_score", "other_score"], + optimization_targets=['main_score', 'other_score'], objective_weights=objective_weights, **kwargs, ) @@ -96,28 +85,27 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {"x", "y"} + assert set(suggestion.columns) == {'x', 'y'} # Check suggestion values are the expected dtype assert isinstance(suggestion.x.iloc[0], np.integer) assert isinstance(suggestion.y.iloc[0], np.floating) # Check that suggestion is in the space test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() - ) + optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
observation = objective(suggestion) assert isinstance(observation, pd.DataFrame) - assert set(observation.columns) == {"main_score", "other_score"} + assert set(observation.columns) == {'main_score', 'other_score'} optimizer.register(configs=suggestion, scores=observation) (best_config, best_score, best_context) = optimizer.get_best_observations() assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {"x", "y"} - assert set(best_score.columns) == {"main_score", "other_score"} + assert set(best_config.columns) == {'x', 'y'} + assert set(best_score.columns) == {'main_score', 'other_score'} assert best_config.shape == (1, 2) assert best_score.shape == (1, 2) @@ -125,7 +113,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {"x", "y"} - assert set(all_scores.columns) == {"main_score", "other_score"} + assert set(all_configs.columns) == {'x', 'y'} + assert set(all_scores.columns) == {'main_score', 'other_score'} assert all_configs.shape == (max_iterations, 2) assert all_scores.shape == (max_iterations, 2) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index d5d00d0692..5fd28ca1ed 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -32,24 +32,20 @@ _LOG.setLevel(logging.DEBUG) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_create_optimizer_and_suggest( - configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict], -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: """ Test that we can create an optimizer and get a suggestion from it. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=['score'], + **kwargs ) assert optimizer is not None @@ -66,17 +62,11 @@ def test_create_optimizer_and_suggest( optimizer.register_pending(configs=suggestion, metadata=metadata) -@pytest.mark.parametrize( - ("optimizer_class", "kwargs"), - [ - *[(member.value, {}) for member in OptimizerType], - ], -) -def test_basic_interface_toy_problem( - configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict], -) -> None: +@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ + *[(member.value, {}) for member in OptimizerType], +]) +def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: """ Toy problem to test the optimizers. """ @@ -87,15 +77,17 @@ def test_basic_interface_toy_problem( if optimizer_class == OptimizerType.SMAC.value: # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100. # To avoid having to train more than 25 model iterations, we set a lower number of max iterations. 
- kwargs["max_trials"] = max_iterations * 2 + kwargs['max_trials'] = max_iterations * 2 def objective(x: pd.Series) -> pd.DataFrame: - return pd.DataFrame({"score": (6 * x - 2) ** 2 * np.sin(12 * x - 4)}) + return pd.DataFrame({"score": (6 * x - 2)**2 * np.sin(12 * x - 4)}) # Emukit doesn't allow specifying a random state, so we set the global seed. np.random.seed(SEED) optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=['score'], + **kwargs ) with pytest.raises(ValueError, match="No observations"): @@ -108,12 +100,12 @@ def objective(x: pd.Series) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {"x", "y", "z"} + assert set(suggestion.columns) == {'x', 'y', 'z'} # check that suggestion is in the space configuration = CS.Configuration(optimizer.parameter_space, suggestion.iloc[0].to_dict()) # Raises an error if outside of configuration space configuration.is_valid_configuration() - observation = objective(suggestion["x"]) + observation = objective(suggestion['x']) assert isinstance(observation, pd.DataFrame) optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -121,8 +113,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {"x", "y", "z"} - assert set(best_score.columns) == {"score"} + assert set(best_config.columns) == {'x', 'y', 'z'} + assert set(best_score.columns) == {'score'} assert best_config.shape == (1, 3) assert best_score.shape == (1, 1) assert best_score.score.iloc[0] < -5 @@ -131,8 +123,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {"x", "y", "z"} - assert set(all_scores.columns) == {"score"} + assert set(all_configs.columns) == {'x', 'y', 'z'} + assert set(all_scores.columns) == {'score'} assert all_configs.shape == (20, 3) assert all_scores.shape == (20, 1) @@ -145,36 +137,27 @@ def objective(x: pd.Series) -> pd.DataFrame: assert pred_all.shape == (20,) -@pytest.mark.parametrize( - ("optimizer_type"), - [ - # Enumerate all supported Optimizers - # *[member for member in OptimizerType], - *list(OptimizerType), - ], -) +@pytest.mark.parametrize(('optimizer_type'), [ + # Enumerate all supported Optimizers + # *[member for member in OptimizerType], + *list(OptimizerType), +]) def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: """ Test that all optimizer types are listed in the ConcreteOptimizer constraints. 
""" - assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member - - -@pytest.mark.parametrize( - ("optimizer_type", "kwargs"), - [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - ], -) -def test_create_optimizer_with_factory_method( - configuration_space: CS.ConfigurationSpace, - optimizer_type: Optional[OptimizerType], - kwargs: Optional[dict], -) -> None: + assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member + + +@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument +]) +def test_create_optimizer_with_factory_method(configuration_space: CS.ConfigurationSpace, + optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: """ Test that we can create an optimizer via a factory. """ @@ -183,13 +166,13 @@ def test_create_optimizer_with_factory_method( if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -205,22 +188,16 @@ def test_create_optimizer_with_factory_method( assert myrepr.startswith(optimizer_type.value.__name__) -@pytest.mark.parametrize( - ("optimizer_type", "kwargs"), - [ - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - ( - OptimizerType.SMAC, - { - # Test with default config. - "use_default_config": True, - # 'n_random_init': 10, - }, - ), - ], -) +@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + (OptimizerType.SMAC, { + # Test with default config. + 'use_default_config': True, + # 'n_random_init': 10, + }), +]) def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optional[dict]) -> None: """ Toy problem to test the optimizers with llamatune space adapter. 
@@ -238,8 +215,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=1234) # Add two continuous inputs - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=3)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=3)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=3)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=3)) # Initialize an optimizer that uses LlamaTune space adapter space_adapter_kwargs = { @@ -262,7 +239,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: llamatune_optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=llamatune_optimizer_kwargs, space_adapter_type=SpaceAdapterType.LLAMATUNE, @@ -271,7 +248,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Initialize an optimizer that uses the original space optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=optimizer_kwargs, ) @@ -280,7 +257,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert optimizer.optimizer_parameter_space != llamatune_optimizer.optimizer_parameter_space llamatune_n_random_init = 0 - opt_n_random_init = int(kwargs.get("n_random_init", 0)) + opt_n_random_init = int(kwargs.get('n_random_init', 0)) if optimizer_type == OptimizerType.SMAC: assert isinstance(optimizer, SmacOptimizer) assert isinstance(llamatune_optimizer, SmacOptimizer) @@ -301,10 +278,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # loop for llamatune-optimizer suggestion, metadata = llamatune_optimizer.suggest() - _x, _y = suggestion["x"].iloc[0], suggestion["y"].iloc[0] - assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx( - 3.0, rel=1e-3 - ) # optimizer explores 1-dimensional space + _x, _y = suggestion['x'].iloc[0], suggestion['y'].iloc[0] + assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3., rel=1e-3) # optimizer explores 1-dimensional space observation = objective(suggestion) llamatune_optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -312,32 +287,28 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: best_observation = optimizer.get_best_observations() llamatune_best_observation = llamatune_optimizer.get_best_observations() - for best_config, best_score, best_context in (best_observation, llamatune_best_observation): + for (best_config, best_score, best_context) in (best_observation, llamatune_best_observation): assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {"x", "y"} - assert set(best_score.columns) == {"score"} + assert set(best_config.columns) == {'x', 'y'} + assert set(best_score.columns) == {'score'} (best_config, best_score, _context) = best_observation (llamatune_best_config, llamatune_best_score, _context) = llamatune_best_observation # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that - assert ( - best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] - or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] - ) + assert best_score.score.iloc[0] > 
llamatune_best_score.score.iloc[0] or \ + best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] # Retrieve and check all observations - for all_configs, all_scores, all_contexts in ( - optimizer.get_observations(), - llamatune_optimizer.get_observations(), - ): + for (all_configs, all_scores, all_contexts) in ( + optimizer.get_observations(), llamatune_optimizer.get_observations()): assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {"x", "y"} - assert set(all_scores.columns) == {"score"} + assert set(all_configs.columns) == {'x', 'y'} + assert set(all_scores.columns) == {'score'} assert len(all_configs) == num_iters assert len(all_scores) == num_iters @@ -349,13 +320,12 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses( - BaseOptimizer, pkg_name="mlos_core" # type: ignore[type-abstract] -) +optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses(BaseOptimizer, # type: ignore[type-abstract] + pkg_name='mlos_core') assert optimizer_subclasses -@pytest.mark.parametrize(("optimizer_class"), optimizer_subclasses) +@pytest.mark.parametrize(('optimizer_class'), optimizer_subclasses) def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: """ Test that all optimizer classes are listed in the OptimizerType enum. @@ -364,19 +334,14 @@ def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: assert optimizer_class in optimizer_type_classes -@pytest.mark.parametrize( - ("optimizer_type", "kwargs"), - [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - ], -) -def test_mixed_numerics_type_input_space_types( - optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] -) -> None: +@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument +]) +def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: """ Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
""" @@ -390,19 +355,19 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=["score"], + optimization_targets=['score'], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -416,14 +381,12 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: for _ in range(max_iterations): suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) - assert (suggestion.columns == ["x", "y"]).all() + assert (suggestion.columns == ['x', 'y']).all() # Check suggestion values are the expected dtype - assert isinstance(suggestion["x"].iloc[0], np.integer) - assert isinstance(suggestion["y"].iloc[0], np.floating) + assert isinstance(suggestion['x'].iloc[0], np.integer) + assert isinstance(suggestion['y'].iloc[0], np.floating) # Check that suggestion is in the space - test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() - ) + test_configuration = CS.Configuration(optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py index 13a28d242d..37b8aa3a69 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py @@ -20,33 +20,22 @@ def test_identity_adapter() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) - ) + CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="float_1", lower=0, upper=100) - ) + CS.UniformFloatHyperparameter(name='float_1', lower=0, upper=100)) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) - ) + CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) adapter = IdentityAdapter(orig_parameter_space=input_space) num_configs = 10 - for sampled_config in input_space.sample_configuration( - size=num_configs - ): # pylint: disable=not-an-iterable # (false positive) - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + for sampled_config in input_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # (false positive) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) target_config_df = adapter.inverse_transform(sampled_config_df) assert target_config_df.equals(sampled_config_df) - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) assert target_config == sampled_config orig_config_df = adapter.transform(target_config_df) assert orig_config_df.equals(sampled_config_df) - orig_config = CS.Configuration( - adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() - ) + orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) assert orig_config == sampled_config diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index cd1b250ab7..84dcd4e5c0 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -30,46 +30,34 @@ def construct_parameter_space( for idx in range(n_continuous_params): input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64) - ) + CS.UniformFloatHyperparameter(name=f'cont_{idx}', lower=0, upper=64)) for idx in range(n_integer_params): input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=-1, upper=256) - ) + CS.UniformIntegerHyperparameter(name=f'int_{idx}', lower=-1, upper=256)) for idx in range(n_categorical_params): input_space.add_hyperparameter( - CS.CategoricalHyperparameter( - name=f"str_{idx}", choices=[f"option_{idx}" for idx in range(5)] - ) - ) + CS.CategoricalHyperparameter(name=f'str_{idx}', choices=[f'option_{idx}' for idx in range(5)])) return input_space -@pytest.mark.parametrize( - ("num_target_space_dims", "param_space_kwargs"), - ( - [ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for 
param_space_kwargs in ( - {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, - {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, - {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), - "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), - "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) - ] - ), -) -def test_num_low_dims( - num_target_space_dims: int, param_space_kwargs: dict -) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) +])) +def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's low-to-high space projection method. """ @@ -78,7 +66,8 @@ def test_num_low_dims( # Number of target parameter space dimensions should be fewer than those of the original space with pytest.raises(ValueError): LlamaTuneAdapter( - orig_parameter_space=input_space, num_low_dims=len(list(input_space.keys())) + orig_parameter_space=input_space, + num_low_dims=len(list(input_space.keys())) ) # Enable only low-dimensional space projections @@ -86,15 +75,13 @@ def test_num_low_dims( orig_parameter_space=input_space, num_low_dims=num_target_space_dims, special_param_values=None, - max_unique_values_per_param=None, + max_unique_values_per_param=None ) sampled_configs = adapter.target_parameter_space.sample_configuration(size=100) for sampled_config in sampled_configs: # pylint: disable=not-an-iterable # (false positive) # Transform low-dim config to high-dim point/config - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) orig_config_df = adapter.transform(sampled_config_df) # High-dim (i.e., original) config should be valid @@ -105,28 +92,18 @@ def test_num_low_dims( target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) assert target_config == sampled_config # Try inverse projection (i.e., high-to-low) for previously unseen configs unseen_sampled_configs = adapter.target_parameter_space.sample_configuration(size=25) - for ( - unseen_sampled_config - ) in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) - if ( - unseen_sampled_config in sampled_configs - ): # pylint: disable=unsupported-membership-test # 
(false positive) + for unseen_sampled_config in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) + if unseen_sampled_config in sampled_configs: # pylint: disable=unsupported-membership-test # (false positive) continue - unseen_sampled_config_df = pd.DataFrame( - [unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys()) - ) + unseen_sampled_config_df = pd.DataFrame([unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys())) with pytest.raises(ValueError): - _ = adapter.inverse_transform( - unseen_sampled_config_df - ) # pylint: disable=redefined-variable-type + _ = adapter.inverse_transform(unseen_sampled_config_df) # pylint: disable=redefined-variable-type def test_special_parameter_values_validation() -> None: @@ -135,14 +112,15 @@ def test_special_parameter_values_validation() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name="str", choices=[f"choice_{idx}" for idx in range(5)]) - ) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont", lower=-1, upper=100)) - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="int", lower=0, upper=100)) + CS.CategoricalHyperparameter(name='str', choices=[f'choice_{idx}' for idx in range(5)])) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name='cont', lower=-1, upper=100)) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter(name='int', lower=0, upper=100)) # Only UniformIntegerHyperparameters are currently supported with pytest.raises(NotImplementedError): - special_param_values_dict_1 = {"str": "choice_1"} + special_param_values_dict_1 = {'str': 'choice_1'} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -151,7 +129,7 @@ def test_special_parameter_values_validation() -> None: ) with pytest.raises(NotImplementedError): - special_param_values_dict_2 = {"cont": -1} + special_param_values_dict_2 = {'cont': -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -160,8 +138,8 @@ def test_special_parameter_values_validation() -> None: ) # Special value should belong to parameter value domain - with pytest.raises(ValueError, match="value domain"): - special_param_values_dict = {"int": -1} + with pytest.raises(ValueError, match='value domain'): + special_param_values_dict = {'int': -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -171,15 +149,15 @@ def test_special_parameter_values_validation() -> None: # Invalid dicts; ValueError should be thrown invalid_special_param_values_dicts: List[Dict[str, Any]] = [ - {"int-Q": 0}, # parameter does not exist - {"int": {0: 0.2}}, # invalid definition - {"int": 0.2}, # invalid parameter value - {"int": (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) - {"int": [0, 0]}, # duplicate special values - {"int": []}, # empty list - {"int": [{0: 0.2}]}, - {"int": [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct - {"int": [(0, 0.1), (0, 0.2)]}, # duplicate special values + {'int-Q': 0}, # parameter does not exist + {'int': {0: 0.2}}, # invalid definition + {'int': 0.2}, # invalid parameter value + {'int': (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) + {'int': [0, 0]}, # duplicate special values + {'int': []}, # empty list + {'int': [{0: 0.2}]}, + {'int': [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct + {'int': [(0, 0.1), (0, 0.2)]}, # duplicate 
special values ] for spv_dict in invalid_special_param_values_dicts: with pytest.raises(ValueError): @@ -192,13 +170,13 @@ def test_special_parameter_values_validation() -> None: # Biasing percentage of special value(s) are invalid invalid_special_param_values_dicts = [ - {"int": (0, 1.1)}, # >1 probability - {"int": (0, 0)}, # Zero probability - {"int": (0, -0.1)}, # Negative probability - {"int": (0, 20)}, # 2,000% instead of 20% - {"int": [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% - {"int": [(0, 0.4), (1, 0.7)]}, # combined probability >100% - {"int": [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. + {'int': (0, 1.1)}, # >1 probability + {'int': (0, 0)}, # Zero probability + {'int': (0, -0.1)}, # Negative probability + {'int': (0, 20)}, # 2,000% instead of 20% + {'int': [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% + {'int': [(0, 0.4), (1, 0.7)]}, # combined probability >100% + {'int': [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. ] for spv_dict in invalid_special_param_values_dicts: @@ -214,27 +192,21 @@ def test_special_parameter_values_validation() -> None: def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[CS.Configuration]: for sampled_config in adapter.target_parameter_space.sample_configuration(size=num_configs): # Transform low-dim config to high-dim config - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) orig_config_df = adapter.transform(sampled_config_df) - orig_config = CS.Configuration( - adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() - ) + orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) yield orig_config -def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex +def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex """ Tests LlamaTune's special parameter values biasing methodology """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) - ) + CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=100) - ) + CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=100)) num_configs = 400 bias_percentage = LlamaTuneAdapter.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE @@ -242,10 +214,10 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp # Single parameter; single special value special_param_value_dicts: List[Dict[str, Any]] = [ - {"int_1": 0}, - {"int_1": (0, bias_percentage)}, - {"int_1": [0]}, - {"int_1": [(0, bias_percentage)]}, + {'int_1': 0}, + {'int_1': (0, bias_percentage)}, + {'int_1': [0]}, + {'int_1': [(0, bias_percentage)]} ] for spv_dict in special_param_value_dicts: @@ -257,14 +229,13 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp ) special_value_occurrences = sum( - 1 for config in gen_random_configs(adapter, num_configs) if config["int_1"] == 0 - ) + 1 for config in gen_random_configs(adapter, num_configs) if config['int_1'] == 0) assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences # Single parameter; multiple special values special_param_value_dicts = [ - {"int_1": [0, 
1]}, - {"int_1": [(0, bias_percentage), (1, bias_percentage)]}, + {'int_1': [0, 1]}, + {'int_1': [(0, bias_percentage), (1, bias_percentage)]} ] for spv_dict in special_param_value_dicts: @@ -277,9 +248,9 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp special_values_occurrences = {0: 0, 1: 0} for config in gen_random_configs(adapter, num_configs): - if config["int_1"] == 0: + if config['int_1'] == 0: special_values_occurrences[0] += 1 - elif config["int_1"] == 1: + elif config['int_1'] == 1: special_values_occurrences[1] += 1 assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[0] @@ -287,8 +258,8 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp # Multiple parameters; multiple special values; different biasing percentage spv_dict = { - "int_1": [(0, bias_percentage), (1, bias_percentage / 2)], - "int_2": [(2, bias_percentage / 2), (100, bias_percentage * 1.5)], + 'int_1': [(0, bias_percentage), (1, bias_percentage / 2)], + 'int_2': [(2, bias_percentage / 2), (100, bias_percentage * 1.5)] } adapter = LlamaTuneAdapter( orig_parameter_space=input_space, @@ -298,30 +269,24 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-comp ) special_values_instances: Dict[str, Dict[int, int]] = { - "int_1": {0: 0, 1: 0}, - "int_2": {2: 0, 100: 0}, + 'int_1': {0: 0, 1: 0}, + 'int_2': {2: 0, 100: 0}, } for config in gen_random_configs(adapter, num_configs): - if config["int_1"] == 0: - special_values_instances["int_1"][0] += 1 - elif config["int_1"] == 1: - special_values_instances["int_1"][1] += 1 - - if config["int_2"] == 2: - special_values_instances["int_2"][2] += 1 - elif config["int_2"] == 100: - special_values_instances["int_2"][100] += 1 - - assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances["int_1"][0] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_1"][ - 1 - ] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_2"][ - 2 - ] - assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances[ - "int_2" - ][100] + if config['int_1'] == 0: + special_values_instances['int_1'][0] += 1 + elif config['int_1'] == 1: + special_values_instances['int_1'][1] += 1 + + if config['int_2'] == 2: + special_values_instances['int_2'][2] += 1 + elif config['int_2'] == 100: + special_values_instances['int_2'][100] += 1 + + assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances['int_1'][0] + assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_1'][1] + assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_2'][2] + assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances['int_2'][100] def test_max_unique_values_per_param() -> None: @@ -330,22 +295,18 @@ def test_max_unique_values_per_param() -> None: """ # Define config space with a mix of different parameter types input_space = CS.ConfigurationSpace(seed=1234) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5)) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name="cont_2", lower=1, upper=100) - ) + CS.UniformFloatHyperparameter(name='cont_1', lower=0, upper=5)) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_1", lower=1, upper=10) - ) + CS.UniformFloatHyperparameter(name='cont_2', 
lower=1, upper=100)) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=2048) - ) + CS.UniformIntegerHyperparameter(name='int_1', lower=1, upper=10)) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) - ) + CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=2048)) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name="str_2", choices=[f"choice_{idx}" for idx in range(10)]) - ) + CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) + input_space.add_hyperparameter( + CS.CategoricalHyperparameter(name='str_2', choices=[f'choice_{idx}' for idx in range(10)])) # Restrict the number of unique parameter values num_configs = 200 @@ -368,30 +329,23 @@ def test_max_unique_values_per_param() -> None: assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize( - ("num_target_space_dims", "param_space_kwargs"), - ( - [ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, - {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, - {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), - "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), - "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) - ] - ), -) -def test_approx_inverse_mapping( - num_target_space_dims: int, param_space_kwargs: dict -) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) +])) +def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals """ Tests LlamaTune's approximate high-to-low space projection method, using pseudo-inverse. 
""" @@ -406,11 +360,9 @@ def test_approx_inverse_mapping( use_approximate_reverse_mapping=False, ) - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.raises(ValueError): - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) _ = adapter.inverse_transform(sampled_config_df) # Enable low-dimensional space projection *and* reverse mapping @@ -423,63 +375,41 @@ def test_approx_inverse_mapping( ) # Warning should be printed the first time - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.warns(UserWarning): - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) adapter.target_parameter_space.check_configuration(target_config) # Test inverse transform with 100 random configs for _ in range(100): - sampled_config = input_space.sample_configuration() # size=1) - sampled_config_df = pd.DataFrame( - [sampled_config.values()], columns=list(sampled_config.keys()) - ) + sampled_config = input_space.sample_configuration() # size=1) + sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) adapter.target_parameter_space.check_configuration(target_config) -@pytest.mark.parametrize( - ("num_low_dims", "special_param_values", "max_unique_values_per_param"), - ( - [ - (num_low_dims, special_param_values, max_unique_values_per_param) - for num_low_dims in (8, 16) - for special_param_values in ( - {"int_1": -1, "int_2": -1, "int_3": -1, "int_4": [-1, 0]}, - { - "int_1": (-1, 0.1), - "int_2": -1, - "int_3": (-1, 0.3), - "int_4": [(-1, 0.1), (0, 0.2)], - }, - ) - for max_unique_values_per_param in (50, 250) - ] - ), -) -def test_llamatune_pipeline( - num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int -) -> None: +@pytest.mark.parametrize(('num_low_dims', 'special_param_values', 'max_unique_values_per_param'), ([ + (num_low_dims, special_param_values, max_unique_values_per_param) + for num_low_dims in (8, 16) + for special_param_values in ( + {'int_1': -1, 'int_2': -1, 'int_3': -1, 'int_4': [-1, 0]}, + {'int_1': (-1, 0.1), 'int_2': -1, 'int_3': (-1, 0.3), 'int_4': [(-1, 0.1), (0, 0.2)]}, + ) + for max_unique_values_per_param in (50, 250) +])) +def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int) -> None: """ Tests LlamaTune space adapter when all components are active. 
""" # pylint: disable=too-many-locals # Define config space with a mix of different parameter types - input_space = construct_parameter_space( - n_continuous_params=10, n_integer_params=10, n_categorical_params=5 - ) + input_space = construct_parameter_space(n_continuous_params=10, n_integer_params=10, n_categorical_params=5) adapter = LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=num_low_dims, @@ -489,14 +419,12 @@ def test_llamatune_pipeline( special_value_occurrences = { param: {special_value: 0 for special_value, _ in tuples_list} - for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access + for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access } unique_values_dict: Dict[str, Set] = {param: set() for param in input_space.keys()} num_configs = 1000 - for config in adapter.target_parameter_space.sample_configuration( - size=num_configs - ): # pylint: disable=not-an-iterable + for config in adapter.target_parameter_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # Transform low-dim config to high-dim point/config sampled_config_df = pd.DataFrame([config.values()], columns=list(config.keys())) orig_config_df = adapter.transform(sampled_config_df) @@ -507,9 +435,7 @@ def test_llamatune_pipeline( # Transform high-dim config back to low-dim target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() - ) + target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) assert target_config == config for param, value in orig_config.items(): @@ -523,48 +449,35 @@ def test_llamatune_pipeline( # Ensure that occurrences of special values do not significantly deviate from expected eps = 0.2 - for ( - param, - tuples_list, - ) in adapter._special_param_values_dict.items(): # pylint: disable=protected-access + for param, tuples_list in adapter._special_param_values_dict.items(): # pylint: disable=protected-access for value, bias_percentage in tuples_list: - assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[ - param - ][value] + assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[param][value] # Ensure that number of unique values is less than the maximum number allowed for _, unique_values in unique_values_dict.items(): assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize( - ("num_target_space_dims", "param_space_kwargs"), - ( - [ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, - {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, - {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), - "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), - "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) - ] - ), -) -def test_deterministic_behavior_for_same_seed( - num_target_space_dims: int, param_space_kwargs: dict -) -> None: +@pytest.mark.parametrize(('num_target_space_dims', 
'param_space_kwargs'), ([ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, + {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), + 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) +])) +def test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_space_kwargs: dict) -> None: """ Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. """ - def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: input_space = construct_parameter_space(**param_space_kwargs, seed=seed) @@ -577,9 +490,7 @@ def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: use_approximate_reverse_mapping=False, ) - sample_configs: List[CS.Configuration] = ( - adapter.target_parameter_space.sample_configuration(size=100) - ) + sample_configs: List[CS.Configuration] = adapter.target_parameter_space.sample_configuration(size=100) return sample_configs assert generate_target_param_space_configs(42) == generate_target_param_space_configs(42) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index 6e5eab7d96..5390f97c5f 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -23,47 +23,39 @@ from mlos_core.tests import get_all_concrete_subclasses -@pytest.mark.parametrize( - ("space_adapter_type"), - [ - # Enumerate all supported SpaceAdapters - # *[member for member in SpaceAdapterType], - *list(SpaceAdapterType), - ], -) +@pytest.mark.parametrize(('space_adapter_type'), [ + # Enumerate all supported SpaceAdapters + # *[member for member in SpaceAdapterType], + *list(SpaceAdapterType), +]) def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: """ Test that all optimizer types are listed in the ConcreteOptimizer constraints. """ # pylint: disable=no-member - assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] + assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] -@pytest.mark.parametrize( - ("space_adapter_type", "kwargs"), - [ - # Default space adapter - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in SpaceAdapterType], - ], -) -def test_create_space_adapter_with_factory_method( - space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict] -) -> None: +@pytest.mark.parametrize(('space_adapter_type', 'kwargs'), [ + # Default space adapter + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in SpaceAdapterType], +]) +def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict]) -> None: # Start defining a ConfigurationSpace for the Optimizer to search. 
input_space = CS.ConfigurationSpace(seed=1234) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=1)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=1)) # Adjust some kwargs for specific space adapters if space_adapter_type is SpaceAdapterType.LLAMATUNE: if kwargs is None: kwargs = {} - kwargs.setdefault("num_low_dims", 1) + kwargs.setdefault('num_low_dims', 1) space_adapter: BaseSpaceAdapter if space_adapter_type is None: @@ -81,25 +73,21 @@ def test_create_space_adapter_with_factory_method( assert space_adapter is not None assert space_adapter.orig_parameter_space is not None myrepr = repr(space_adapter) - assert myrepr.startswith( - space_adapter_type.value.__name__ - ), f"Expected {space_adapter_type.value.__name__} but got {myrepr}" + assert myrepr.startswith(space_adapter_type.value.__name__), \ + f"Expected {space_adapter_type.value.__name__} but got {myrepr}" # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = get_all_concrete_subclasses( - BaseSpaceAdapter, pkg_name="mlos_core" -) # type: ignore[type-abstract] +space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = \ + get_all_concrete_subclasses(BaseSpaceAdapter, pkg_name='mlos_core') # type: ignore[type-abstract] assert space_adapter_subclasses -@pytest.mark.parametrize(("space_adapter_class"), space_adapter_subclasses) +@pytest.mark.parametrize(('space_adapter_class'), space_adapter_subclasses) def test_space_adapter_type_defs(space_adapter_class: Type[BaseSpaceAdapter]) -> None: """ Test that all space adapter classes are listed in the SpaceAdapterType enum. """ - space_adapter_type_classes = { - space_adapter_type.value for space_adapter_type in SpaceAdapterType - } + space_adapter_type_classes = {space_adapter_type.value for space_adapter_type in SpaceAdapterType} assert space_adapter_class in space_adapter_type_classes diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index f7cde8ae88..dee9251652 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -41,9 +41,9 @@ def assert_is_uniform(arr: npt.NDArray) -> None: assert np.isclose(frequencies.sum(), 1) _f_chi_sq, f_p_value = scipy.stats.chisquare(frequencies) - assert np.isclose(kurtosis, -1.2, atol=0.1) - assert p_value > 0.3 - assert f_p_value > 0.5 + assert np.isclose(kurtosis, -1.2, atol=.1) + assert p_value > .3 + assert f_p_value > .5 def assert_is_log_uniform(arr: npt.NDArray, base: float = np.e) -> None: @@ -70,14 +70,13 @@ def invalid_conversion_function(*args: Any) -> NoReturn: """ A quick dummy function for the base class to make pylint happy. """ - raise NotImplementedError("subclass must override conversion_function") + raise NotImplementedError('subclass must override conversion_function') class BaseConversion(metaclass=ABCMeta): """ Base class for testing optimizer space conversions. 
""" - conversion_function: Callable[..., OptimizerSpace] = invalid_conversion_function @abstractmethod @@ -151,8 +150,8 @@ def test_uniform_samples(self) -> None: assert_is_uniform(uniform) # Check that we get both ends of the sampled range returned to us. - assert input_space["c"].lower in integer_uniform - assert input_space["c"].upper in integer_uniform + assert input_space['c'].lower in integer_uniform + assert input_space['c'].upper in integer_uniform # integer uniform assert_is_uniform(integer_uniform) @@ -166,13 +165,13 @@ def test_uniform_categorical(self) -> None: assert 35 < counts[1] < 65 def test_weighted_categorical(self) -> None: - raise NotImplementedError("subclass must override") + raise NotImplementedError('subclass must override') def test_log_int_spaces(self) -> None: - raise NotImplementedError("subclass must override") + raise NotImplementedError('subclass must override') def test_log_float_spaces(self) -> None: - raise NotImplementedError("subclass must override") + raise NotImplementedError('subclass must override') class TestFlamlConversion(BaseConversion): @@ -185,12 +184,10 @@ class TestFlamlConversion(BaseConversion): def sample(self, config_space: FlamlSpace, n_samples: int = 1) -> npt.NDArray: # type: ignore[override] assert isinstance(config_space, dict) assert isinstance(next(iter(config_space.values())), flaml.tune.sample.Domain) - ret: npt.NDArray = np.array( - [domain.sample(size=n_samples) for domain in config_space.values()] - ).T + ret: npt.NDArray = np.array([domain.sample(size=n_samples) for domain in config_space.values()]).T return ret - def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] + def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] assert isinstance(config_space, dict) ret: List[str] = list(config_space.keys()) return ret @@ -211,9 +208,7 @@ def test_dimensionality(self) -> None: def test_weighted_categorical(self) -> None: np.random.seed(42) input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1]) - ) + input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1])) with pytest.raises(ValueError, match="non-uniform"): configspace_to_flaml_space(input_space) @@ -222,9 +217,7 @@ def test_log_int_spaces(self) -> None: np.random.seed(42) # integer is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True) - ) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True)) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -242,9 +235,7 @@ def test_log_float_spaces(self) -> None: # continuous is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True) - ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True)) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -254,6 +245,6 @@ def test_log_float_spaces(self) -> None: assert_is_log_uniform(float_log_uniform) -if __name__ == "__main__": +if __name__ == '__main__': # For attaching debugger debugging: pytest.main(["-vv", "-k", "test_log_int_spaces", __file__]) diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index 
e6fa12522e..df0e144535 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -28,9 +28,7 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame: return pd.DataFrame([dict(config)]) -def normalize_config( - config_space: ConfigurationSpace, config: Union[Configuration, dict] -) -> Configuration: +def normalize_config(config_space: ConfigurationSpace, config: Union[Configuration, dict]) -> Configuration: """ Convert a dictionary to a valid ConfigSpace configuration. @@ -51,6 +49,8 @@ def normalize_config( """ cs_config = Configuration(config_space, values=config, allow_inactive_with_values=True) return Configuration( - config_space, - values={key: cs_config[key] for key in config_space.get_active_hyperparameters(cs_config)}, + config_space, values={ + key: cs_config[key] + for key in config_space.get_active_hyperparameters(cs_config) + } ) diff --git a/mlos_core/mlos_core/version.py b/mlos_core/mlos_core/version.py index f946f94aa4..2362de7083 100644 --- a/mlos_core/mlos_core/version.py +++ b/mlos_core/mlos_core/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = "0.5.1" +VERSION = '0.5.1' if __name__ == "__main__": print(VERSION) diff --git a/mlos_core/setup.py b/mlos_core/setup.py index 4d895db315..fed376d1af 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -21,16 +21,15 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns["VERSION"] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns['VERSION'] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - - version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) + version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -50,54 +49,52 @@ # we return nothing when the file is not available. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, "README.md") + readme_path = os.path.join(pkg_dir, 'README.md') if not os.path.isfile(readme_path): return { - "long_description": "missing", + 'long_description': 'missing', } - jsonc_re = re.compile(r"```jsonc") - link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") - with open(readme_path, mode="r", encoding="utf-8") as readme_fh: + jsonc_re = re.compile(r'```jsonc') + link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') + with open(readme_path, mode='r', encoding='utf-8') as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] # Tweak source source code links. 
- lines = [jsonc_re.sub(r"```json", line) for line in lines] + lines = [jsonc_re.sub(r'```json', line) for line in lines] return { - "long_description": "".join(lines), - "long_description_content_type": "text/markdown", + 'long_description': ''.join(lines), + 'long_description_content_type': 'text/markdown', } extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass - "flaml": ["flaml[blendsearch]"], - "smac": ["smac>=2.0.0"], # NOTE: Major refactoring on SMAC starting from v2.0.0 + 'flaml': ['flaml[blendsearch]'], + 'smac': ['smac>=2.0.0'], # NOTE: Major refactoring on SMAC starting from v2.0.0 } # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires["full"] = list(set(chain(*extra_requires.values()))) +extra_requires['full'] = list(set(chain(*extra_requires.values()))) -extra_requires["full-tests"] = extra_requires["full"] + [ - "pytest", - "pytest-forked", - "pytest-xdist", - "pytest-cov", - "pytest-local-badge", +extra_requires['full-tests'] = extra_requires['full'] + [ + 'pytest', + 'pytest-forked', + 'pytest-xdist', + 'pytest-cov', + 'pytest-local-badge', ] setup( version=VERSION, install_requires=[ - "scikit-learn>=1.2", - "joblib>=1.1.1", # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released - "scipy>=1.3.2", - "numpy>=1.24", - "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710 - 'pandas >= 2.2.0;python_version>="3.9"', - 'Bottleneck > 1.3.5;python_version>="3.9"', + 'scikit-learn>=1.2', + 'joblib>=1.1.1', # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released + 'scipy>=1.3.2', + 'numpy>=1.24', 'numpy<2.0.0', # FIXME: https://github.com/numpy/numpy/issues/26710 + 'pandas >= 2.2.0;python_version>="3.9"', 'Bottleneck > 1.3.5;python_version>="3.9"', 'pandas >= 1.0.3;python_version<"3.9"', - "ConfigSpace>=0.7.1", + 'ConfigSpace>=0.7.1', ], extras_require=extra_requires, **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_core"), diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py index 1725a24ed9..2390554e1e 100644 --- a/mlos_viz/mlos_viz/__init__.py +++ b/mlos_viz/mlos_viz/__init__.py @@ -23,7 +23,7 @@ class MlosVizMethod(Enum): """ DABL = "dabl" - AUTO = DABL # use dabl as the current default + AUTO = DABL # use dabl as the current default def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None: @@ -39,21 +39,17 @@ def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) base.ignore_plotter_warnings() if plotter_method == MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel - mlos_viz.dabl.ignore_plotter_warnings() else: raise NotImplementedError(f"Unhandled method: {plotter_method}") -def plot( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - plotter_method: MlosVizMethod = MlosVizMethod.AUTO, - filter_warnings: bool = True, - **kwargs: Any, -) -> None: +def plot(exp_data: Optional[ExperimentData] = None, *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + plotter_method: MlosVizMethod = MlosVizMethod.AUTO, + filter_warnings: bool = True, + **kwargs: Any) -> None: """ Plots the results of the experiment. 
@@ -85,7 +81,6 @@ def plot( if MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel - mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives) else: raise NotImplementedError(f"Unhandled method: {plotter_method}") diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index d2fc4edad7..15358b0862 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -20,7 +20,7 @@ from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_viz.util import expand_results_data_args -_SEABORN_VERS = version("seaborn") +_SEABORN_VERS = version('seaborn') def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: @@ -30,7 +30,7 @@ def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: Note: this only works with non-positional kwargs (e.g., those after a * arg). """ target_kwargs = {} - for kword in target.__kwdefaults__: # or {} # intentionally omitted for now + for kword in target.__kwdefaults__: # or {} # intentionally omitted for now if kword in kwargs: target_kwargs[kword] = kwargs[kword] return target_kwargs @@ -42,19 +42,14 @@ def ignore_plotter_warnings() -> None: adding them to the warnings filter. """ warnings.filterwarnings("ignore", category=FutureWarning) - if _SEABORN_VERS <= "0.13.1": - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - module="seaborn", # but actually comes from pandas - message="is_categorical_dtype is deprecated and will be removed in a future version.", - ) + if _SEABORN_VERS <= '0.13.1': + warnings.filterwarnings("ignore", category=DeprecationWarning, module="seaborn", # but actually comes from pandas + message="is_categorical_dtype is deprecated and will be removed in a future version.") -def _add_groupby_desc_column( - results_df: pandas.DataFrame, - groupby_columns: Optional[List[str]] = None, -) -> Tuple[pandas.DataFrame, List[str], str]: +def _add_groupby_desc_column(results_df: pandas.DataFrame, + groupby_columns: Optional[List[str]] = None, + ) -> Tuple[pandas.DataFrame, List[str], str]: """ Adds a group descriptor column to the results_df. @@ -72,19 +67,17 @@ def _add_groupby_desc_column( if groupby_columns is None: groupby_columns = ["tunable_config_trial_group_id", "tunable_config_id"] groupby_column = ",".join(groupby_columns) - results_df[groupby_column] = ( - results_df[groupby_columns].astype(str).apply(lambda x: ",".join(x), axis=1) - ) # pylint: disable=unnecessary-lambda + results_df[groupby_column] = results_df[groupby_columns].astype(str).apply( + lambda x: ",".join(x), axis=1) # pylint: disable=unnecessary-lambda groupby_columns.append(groupby_column) return (results_df, groupby_columns, groupby_column) -def augment_results_df_with_config_trial_group_stats( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - requested_result_cols: Optional[Iterable[str]] = None, -) -> pandas.DataFrame: +def augment_results_df_with_config_trial_group_stats(exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + requested_result_cols: Optional[Iterable[str]] = None, + ) -> pandas.DataFrame: # pylint: disable=too-complex """ Add a number of useful statistical measure columns to the results dataframe. 
@@ -141,46 +134,30 @@ def augment_results_df_with_config_trial_group_stats( raise ValueError(f"Not enough data: {len(results_groups)}") if requested_result_cols is None: - result_cols = set( - col - for col in results_df.columns - if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) - ) + result_cols = set(col for col in results_df.columns if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX)) else: - result_cols = set( - col - for col in requested_result_cols - if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns - ) - result_cols.update( - set( - ExperimentData.RESULT_COLUMN_PREFIX + col - for col in requested_result_cols - if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns - ) - ) + result_cols = set(col for col in requested_result_cols + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns) + result_cols.update(set(ExperimentData.RESULT_COLUMN_PREFIX + col for col in requested_result_cols + if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns)) def compute_zscore_for_group_agg( - results_groups_perf: "SeriesGroupBy", - stats_df: pandas.DataFrame, - result_col: str, - agg: Union[Literal["mean"], Literal["var"], Literal["std"]], + results_groups_perf: "SeriesGroupBy", + stats_df: pandas.DataFrame, + result_col: str, + agg: Union[Literal["mean"], Literal["var"], Literal["std"]] ) -> None: - results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? + results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe. stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean() stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std() - stats_df[result_col + f".{agg}_zscore"] = ( - stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"] - ) / stats_df[result_col + f".{agg}_stddev"] - stats_df.drop( - columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True - ) + stats_df[result_col + f".{agg}_zscore"] = \ + (stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"]) \ + / stats_df[result_col + f".{agg}_stddev"] + stats_df.drop(columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True) augmented_results_df = results_df - augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform( - "count" - ) + augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform("count") for result_col in result_cols: if not result_col.startswith(ExperimentData.RESULT_COLUMN_PREFIX): continue @@ -199,21 +176,20 @@ def compute_zscore_for_group_agg( compute_zscore_for_group_agg(results_groups_perf, stats_df, result_col, "var") quantiles = [0.50, 0.75, 0.90, 0.95, 0.99] - for quantile in quantiles: # TODO: can we do this in one pass? + for quantile in quantiles: # TODO: can we do this in one pass? 
quantile_col = f"{result_col}.p{int(quantile * 100)}" stats_df[quantile_col] = results_groups_perf.transform("quantile", quantile) augmented_results_df = pandas.concat([augmented_results_df, stats_df], axis=1) return augmented_results_df -def limit_top_n_configs( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - top_n_configs: int = 10, - method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", -) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: +def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + top_n_configs: int = 10, + method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", + ) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: # pylint: disable=too-many-locals """ Utility function to process the results and determine the best performing @@ -243,9 +219,7 @@ def limit_top_n_configs( raise ValueError(f"Invalid method: {method}") # Prepare the orderby columns. - (results_df, objs_cols) = expand_results_data_args( - exp_data, results_df=results_df, objectives=objectives - ) + (results_df, objs_cols) = expand_results_data_args(exp_data, results_df=results_df, objectives=objectives) assert isinstance(results_df, pandas.DataFrame) # Augment the results dataframe with some useful stats. @@ -258,17 +232,13 @@ def limit_top_n_configs( # results_df is not None and is in fact a DataFrame, so we periodically assert # it in this func for now. assert results_df is not None - orderby_cols: Dict[str, bool] = { - obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items() - } + orderby_cols: Dict[str, bool] = {obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items()} config_id_col = "tunable_config_id" - group_id_col = "tunable_config_trial_group_id" # first trial_id per config group + group_id_col = "tunable_config_trial_group_id" # first trial_id per config group trial_id_col = "trial_id" - default_config_id = ( - results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id - ) + default_config_id = results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id assert default_config_id is not None, "Failed to determine default config id." # Filter out configs whose variance is too large. @@ -280,18 +250,16 @@ def limit_top_n_configs( singletons_mask = results_df["tunable_config_trial_group_size"] == 1 else: singletons_mask = results_df["tunable_config_trial_group_size"] > 1 - results_df = results_df.loc[ - ( - (results_df[f"{obj_col}.var_zscore"].abs() < 2) - | (singletons_mask) - | (results_df[config_id_col] == default_config_id) - ) - ] + results_df = results_df.loc[( + (results_df[f"{obj_col}.var_zscore"].abs() < 2) + | (singletons_mask) + | (results_df[config_id_col] == default_config_id) + )] assert results_df is not None # Also, filter results that are worse than the default. 
default_config_results_df = results_df.loc[results_df[config_id_col] == default_config_id] - for orderby_col, ascending in orderby_cols.items(): + for (orderby_col, ascending) in orderby_cols.items(): default_vals = default_config_results_df[orderby_col].unique() assert len(default_vals) == 1 default_val = default_vals[0] @@ -303,38 +271,29 @@ def limit_top_n_configs( # Now regroup and filter to the top-N configs by their group performance dimensions. assert results_df is not None - group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[ - orderby_cols.keys() - ] - top_n_config_ids: List[int] = ( - group_results_df.sort_values( - by=list(orderby_cols.keys()), ascending=list(orderby_cols.values()) - ) - .head(top_n_configs) - .index.tolist() - ) + group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[orderby_cols.keys()] + top_n_config_ids: List[int] = group_results_df.sort_values( + by=list(orderby_cols.keys()), ascending=list(orderby_cols.values())).head(top_n_configs).index.tolist() # Remove the default config if it's included. We'll add it back later. if default_config_id in top_n_config_ids: top_n_config_ids.remove(default_config_id) # Get just the top-n config results. # Sort by the group ids. - top_n_config_results_df = results_df.loc[ - (results_df[config_id_col].isin(top_n_config_ids)) - ].sort_values([group_id_col, config_id_col, trial_id_col]) + top_n_config_results_df = results_df.loc[( + results_df[config_id_col].isin(top_n_config_ids) + )].sort_values([group_id_col, config_id_col, trial_id_col]) # Place the default config at the top of the list. top_n_config_ids.insert(0, default_config_id) - top_n_config_results_df = pandas.concat( - [default_config_results_df, top_n_config_results_df], axis=0 - ) + top_n_config_results_df = pandas.concat([default_config_results_df, top_n_config_results_df], axis=0) return (top_n_config_results_df, top_n_config_ids, orderby_cols) def plot_optimizer_trends( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, ) -> None: """ Plots the optimizer trends for the Experiment. @@ -353,16 +312,12 @@ def plot_optimizer_trends( (results_df, obj_cols) = expand_results_data_args(exp_data, results_df, objectives) (results_df, groupby_columns, groupby_column) = _add_groupby_desc_column(results_df) - for objective_column, ascending in obj_cols.items(): + for (objective_column, ascending) in obj_cols.items(): incumbent_column = objective_column + ".incumbent" # Determine the mean of each config trial group to match the box plots. 
- group_results_df = ( - results_df.groupby(groupby_columns)[objective_column] - .mean() - .reset_index() - .sort_values(groupby_columns) - ) + group_results_df = results_df.groupby(groupby_columns)[objective_column].mean()\ + .reset_index().sort_values(groupby_columns) # # Note: technically the optimizer (usually) uses the *first* result for a # given config trial group before moving on to a new config (x-axis), so @@ -400,29 +355,24 @@ def plot_optimizer_trends( ax=axis, ) - plt.yscale("log") + plt.yscale('log') plt.ylabel(objective_column.replace(ExperimentData.RESULT_COLUMN_PREFIX, "")) plt.xlabel("Config Trial Group ID, Config ID") plt.xticks(rotation=90, fontsize=8) - plt.title( - "Optimizer Trends for Experiment: " + exp_data.experiment_id - if exp_data is not None - else "" - ) + plt.title("Optimizer Trends for Experiment: " + exp_data.experiment_id if exp_data is not None else "") plt.grid() plt.show() # type: ignore[no-untyped-call] -def plot_top_n_configs( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - with_scatter_plot: bool = False, - **kwargs: Any, -) -> None: +def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + with_scatter_plot: bool = False, + **kwargs: Any, + ) -> None: # pylint: disable=too-many-locals """ Plots the top-N configs along with the default config for the given ExperimentData. @@ -450,16 +400,12 @@ def plot_top_n_configs( top_n_config_args["results_df"] = results_df if "objectives" not in top_n_config_args: top_n_config_args["objectives"] = objectives - (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs( - exp_data=exp_data, **top_n_config_args - ) + (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs(exp_data=exp_data, **top_n_config_args) - (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column( - top_n_config_results_df - ) + (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column(top_n_config_results_df) top_n = len(top_n_config_results_df[groupby_column].unique()) - 1 - for orderby_col, ascending in orderby_cols.items(): + for (orderby_col, ascending) in orderby_cols.items(): opt_tgt = orderby_col.replace(ExperimentData.RESULT_COLUMN_PREFIX, "") (_fig, axis) = plt.subplots() sns.violinplot( @@ -479,12 +425,12 @@ def plot_top_n_configs( plt.grid() (xticks, xlabels) = plt.xticks() # default should be in the first position based on top_n_configs() return - xlabels[0] = "default" # type: ignore[call-overload] - plt.xticks(xticks, xlabels) # type: ignore[arg-type] + xlabels[0] = "default" # type: ignore[call-overload] + plt.xticks(xticks, xlabels) # type: ignore[arg-type] plt.xlabel("Config Trial Group, Config ID") plt.xticks(rotation=90) plt.ylabel(opt_tgt) - plt.yscale("log") + plt.yscale('log') extra_title = "(lower is better)" if ascending else "(lower is better)" plt.title(f"Top {top_n} configs {opt_tgt} {extra_title}") plt.show() # type: ignore[no-untyped-call] diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index beeba3248f..504486a58c 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -15,12 +15,10 @@ from mlos_viz.util import expand_results_data_args -def plot( - exp_data: Optional[ExperimentData] = None, - *, - results_df: 
Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, -) -> None: +def plot(exp_data: Optional[ExperimentData] = None, *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + ) -> None: """ Plots the Experiment results data using dabl. @@ -46,45 +44,17 @@ def ignore_plotter_warnings() -> None: """ # pylint: disable=import-outside-toplevel warnings.filterwarnings("ignore", category=FutureWarning) - warnings.filterwarnings( - "ignore", module="dabl", category=UserWarning, message="Could not infer format" - ) - warnings.filterwarnings( - "ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers" - ) - warnings.filterwarnings( - "ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated" - ) - warnings.filterwarnings( - "ignore", - module="dabl", - category=UserWarning, - message="Missing values in target_col have been removed for regression", - ) + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Could not infer format") + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers") + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated") + warnings.filterwarnings("ignore", module="dabl", category=UserWarning, + message="Missing values in target_col have been removed for regression") from sklearn.exceptions import UndefinedMetricWarning - - warnings.filterwarnings( - "ignore", - module="sklearn", - category=UndefinedMetricWarning, - message="Recall is ill-defined", - ) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - message="is_categorical_dtype is deprecated and will be removed in a future version.", - ) - warnings.filterwarnings( - "ignore", - category=DeprecationWarning, - module="sklearn", - message="is_sparse is deprecated and will be removed in a future version.", - ) + warnings.filterwarnings("ignore", module="sklearn", category=UndefinedMetricWarning, message="Recall is ill-defined") + warnings.filterwarnings("ignore", category=DeprecationWarning, + message="is_categorical_dtype is deprecated and will be removed in a future version.") + warnings.filterwarnings("ignore", category=DeprecationWarning, module="sklearn", + message="is_sparse is deprecated and will be removed in a future version.") from matplotlib._api.deprecation import MatplotlibDeprecationWarning - - warnings.filterwarnings( - "ignore", - category=MatplotlibDeprecationWarning, - module="dabl", - message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed", - ) + warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning, module="dabl", + message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed") diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py index e5528f9875..06ac4a7664 100644 --- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py +++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py @@ -30,5 +30,5 @@ def test_plot(mock_show: Mock, mock_boxplot: Mock, exp_data: ExperimentData) -> warnings.simplefilter("error") random.seed(42) plot(exp_data, filter_warnings=True) - assert mock_show.call_count >= 2 # from the two base plots and anything dabl did - assert mock_boxplot.call_count >= 1 # from anything dabl did + assert mock_show.call_count >= 2 # from the two base plots and anything dabl 
did + assert mock_boxplot.call_count >= 1 # from anything dabl did diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index 8f426810f8..744fe28648 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -49,14 +49,11 @@ def expand_results_data_args( raise ValueError("Must provide either exp_data or both results_df and objectives.") objectives = exp_data.objectives objs_cols: Dict[str, bool] = {} - for opt_tgt, opt_dir in objectives.items(): + for (opt_tgt, opt_dir) in objectives.items(): if opt_dir not in ["min", "max"]: raise ValueError(f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}") ascending = opt_dir == "min" - if ( - opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) - and opt_tgt in results_df.columns - ): + if opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and opt_tgt in results_df.columns: objs_cols[opt_tgt] = ascending elif ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt in results_df.columns: objs_cols[ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt] = ascending diff --git a/mlos_viz/mlos_viz/version.py b/mlos_viz/mlos_viz/version.py index d418ae43c7..607c7cc014 100644 --- a/mlos_viz/mlos_viz/version.py +++ b/mlos_viz/mlos_viz/version.py @@ -7,7 +7,7 @@ """ # NOTE: This should be managed by bumpversion. -VERSION = "0.5.1" +VERSION = '0.5.1' if __name__ == "__main__": print(VERSION) diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index 638a28469a..98d12598e1 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -21,16 +21,15 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns["VERSION"] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns['VERSION'] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - - version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) + version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -48,22 +47,22 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, "README.md") + readme_path = os.path.join(pkg_dir, 'README.md') if not os.path.isfile(readme_path): return { - "long_description": "missing", + 'long_description': 'missing', } - jsonc_re = re.compile(r"```jsonc") - link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") - with open(readme_path, mode="r", encoding="utf-8") as readme_fh: + jsonc_re = re.compile(r'```jsonc') + link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') + with open(readme_path, mode='r', encoding='utf-8') as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] # Tweak source source code links. 
- lines = [jsonc_re.sub(r"```json", line) for line in lines] + lines = [jsonc_re.sub(r'```json', line) for line in lines] return { - "long_description": "".join(lines), - "long_description_content_type": "text/markdown", + 'long_description': ''.join(lines), + 'long_description_content_type': 'text/markdown', } @@ -71,23 +70,23 @@ def _get_long_desc_from_readme(base_url: str) -> dict: # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires["full"] = list(set(chain(*extra_requires.values()))) +extra_requires['full'] = list(set(chain(*extra_requires.values()))) -extra_requires["full-tests"] = extra_requires["full"] + [ - "pytest", - "pytest-forked", - "pytest-xdist", - "pytest-cov", - "pytest-local-badge", +extra_requires['full-tests'] = extra_requires['full'] + [ + 'pytest', + 'pytest-forked', + 'pytest-xdist', + 'pytest-cov', + 'pytest-local-badge', ] setup( version=VERSION, install_requires=[ - "mlos-bench==" + VERSION, - "dabl>=0.2.6", - "matplotlib<3.9", # FIXME: https://github.com/dabl/dabl/pull/341 + 'mlos-bench==' + VERSION, + 'dabl>=0.2.6', + 'matplotlib<3.9', # FIXME: https://github.com/dabl/dabl/pull/341 ], extras_require=extra_requires, - **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_viz"), + **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_viz'), ) From 648ca631cb99af0bee2f924e4894e57fe63884c9 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 16:50:40 +0000 Subject: [PATCH 17/54] Revert "decrease line length to a longer line than black defaults to but still within the pep8 guidelines" This reverts commit 857d694d2a9d7756cc80412f10dca5990531f19b. --- .editorconfig | 3 --- .pylintrc | 2 +- pyproject.toml | 2 +- setup.cfg | 2 +- 4 files changed, 3 insertions(+), 6 deletions(-) diff --git a/.editorconfig b/.editorconfig index b31e722644..e984d47595 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,9 +12,6 @@ charset = utf-8 # Note: this is not currently supported by all editors or their editorconfig plugins. max_line_length = 132 -[{*.py,*.pyi}] -max_line_length = 99 - # Makefiles need tab indentation [{Makefile,*.mk}] indent_style = tab diff --git a/.pylintrc b/.pylintrc index c6c512ecb7..e686070503 100644 --- a/.pylintrc +++ b/.pylintrc @@ -35,7 +35,7 @@ load-plugins= [FORMAT] # Maximum number of characters on a single line. -max-line-length=99 +max-line-length=132 [MESSAGE CONTROL] disable= diff --git a/pyproject.toml b/pyproject.toml index 6673321b6c..26409fa408 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.black] -line-length = 99 +line-length = 132 target-version = ["py38", "py39", "py310", "py311", "py312"] include = '\.pyi?$' diff --git a/setup.cfg b/setup.cfg index 1f3e7a39d5..9d09b3356d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ match = .+(? 
Date: Mon, 8 Jul 2024 16:54:31 +0000 Subject: [PATCH 18/54] isort change fixups --- .../mlos_bench/tests/optimizers/grid_search_optimizer_test.py | 1 + mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py | 1 + 2 files changed, 2 insertions(+) diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index 9e9ce25d6f..cfecb02058 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -20,6 +20,7 @@ # pylint: disable=redefined-outer-name + @pytest.fixture def grid_search_tunables_config() -> dict: """ diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index 58bb0368b1..0181957cd0 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -14,6 +14,7 @@ # Note: these test do *not* check the ConfigSpace conversions for those same Tunables. # That is checked indirectly via grid_search_optimizer_test.py + def test_tunable_int_size_props() -> None: """Test tunable int size properties""" tunable = Tunable( From 23e218b4b18c4d8d4b4f2787ecde7932c2685dc6 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 16:56:08 +0000 Subject: [PATCH 19/54] minor tweak for ripgrep so that --hidden option works to search dot files without looking at the .git/ directory --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 157dba7a4d..471d653344 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Ignore git directory (ripgrep) +.git/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] From ed6d7e021b8b6ef1a38041032a6efc7fda3ae840 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 17:18:05 +0000 Subject: [PATCH 20/54] add docformatter to aid black --- conda-envs/mlos-3.10.yml | 1 + conda-envs/mlos-3.11.yml | 1 + conda-envs/mlos-3.8.yml | 1 + conda-envs/mlos-3.9.yml | 1 + conda-envs/mlos-windows.yml | 1 + conda-envs/mlos.yml | 1 + pyproject.toml | 6 ++++++ 7 files changed, 12 insertions(+) diff --git a/conda-envs/mlos-3.10.yml b/conda-envs/mlos-3.10.yml index 4614b28a78..75bf64c5bf 100644 --- a/conda-envs/mlos-3.10.yml +++ b/conda-envs/mlos-3.10.yml @@ -28,6 +28,7 @@ dependencies: - bump2version - check-jsonschema - isort + - docformatter - licenseheaders - mypy - pandas-stubs diff --git a/conda-envs/mlos-3.11.yml b/conda-envs/mlos-3.11.yml index 9680186660..6443c7a308 100644 --- a/conda-envs/mlos-3.11.yml +++ b/conda-envs/mlos-3.11.yml @@ -28,6 +28,7 @@ dependencies: - bump2version - check-jsonschema - isort + - docformatter - licenseheaders - mypy - pandas-stubs diff --git a/conda-envs/mlos-3.8.yml b/conda-envs/mlos-3.8.yml index 1cfb0e18d2..8b79aad2c4 100644 --- a/conda-envs/mlos-3.8.yml +++ b/conda-envs/mlos-3.8.yml @@ -28,6 +28,7 @@ dependencies: - bump2version - check-jsonschema - isort + - docformatter - licenseheaders - mypy - pandas-stubs diff --git a/conda-envs/mlos-3.9.yml b/conda-envs/mlos-3.9.yml index 75cee3baee..88b384a428 100644 --- a/conda-envs/mlos-3.9.yml +++ b/conda-envs/mlos-3.9.yml @@ -28,6 +28,7 @@ dependencies: - bump2version - check-jsonschema - isort + - docformatter - licenseheaders - mypy - pandas-stubs diff --git a/conda-envs/mlos-windows.yml b/conda-envs/mlos-windows.yml index 190c2699e5..1287247641 100644 
--- a/conda-envs/mlos-windows.yml +++ b/conda-envs/mlos-windows.yml @@ -31,6 +31,7 @@ dependencies: - bump2version - check-jsonschema - isort + - docformatter - licenseheaders - mypy - pandas-stubs diff --git a/conda-envs/mlos.yml b/conda-envs/mlos.yml index 51ce8077a8..a65633fcfe 100644 --- a/conda-envs/mlos.yml +++ b/conda-envs/mlos.yml @@ -27,6 +27,7 @@ dependencies: - bump2version - check-jsonschema - isort + - docformatter - licenseheaders - mypy - pandas-stubs diff --git a/pyproject.toml b/pyproject.toml index 26409fa408..0bebbed6b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,3 +7,9 @@ include = '\.pyi?$' profile = "black" py_version = 311 src_paths = ["mlos_core", "mlos_bench", "mlos_viz"] + +[tool.docformatter] +recursive = true +black = true +style = "numpy" +blank = true From b28caa7d05fa38a3d8ccff3d7645fa306eac6dcc Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 17:27:35 +0000 Subject: [PATCH 21/54] enable docformatter for docstrings as well --- Makefile | 78 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 70 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index d7ca684bc7..6a1ee8ddcc 100644 --- a/Makefile +++ b/Makefile @@ -69,9 +69,9 @@ ifneq (,$(filter format,$(MAKECMDGOALS))) endif build/format.${CONDA_ENV_NAME}.build-stamp: build/licenseheaders.${CONDA_ENV_NAME}.build-stamp -# TODO: enable isort and black formatters -#build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp -#build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp +build/format.${CONDA_ENV_NAME}.build-stamp: build/isort.${CONDA_ENV_NAME}.build-stamp +build/format.${CONDA_ENV_NAME}.build-stamp: build/black.${CONDA_ENV_NAME}.build-stamp +build/format.${CONDA_ENV_NAME}.build-stamp: build/docformatter.${CONDA_ENV_NAME}.build-stamp build/format.${CONDA_ENV_NAME}.build-stamp: touch $@ @@ -111,8 +111,8 @@ build/isort.${CONDA_ENV_NAME}.build-stamp: # NOTE: when using pattern rules (involving %) we can only add one line of # prerequisities, so we use this pattern to compose the list as variables. -# Both isort and licenseheaders alter files, so only run one at a time, by -# making licenseheaders an order-only prerequisite. +# black, licenseheaders, isort, and docformatter all alter files, so only run +# one at a time, by adding prerequisites, but only as necessary. ISORT_COMMON_PREREQS := ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS))) ISORT_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp @@ -142,8 +142,8 @@ build/black.${CONDA_ENV_NAME}.build-stamp: build/black.mlos_viz.${CONDA_ENV_NAME build/black.${CONDA_ENV_NAME}.build-stamp: touch $@ -# Both black, licenseheaders, and isort all alter files, so only run one at a time, by -# making licenseheaders and isort an order-only prerequisite. +# black, licenseheaders, isort, and docformatter all alter files, so only run +# one at a time, by adding prerequisites, but only as necessary. BLACK_COMMON_PREREQS := ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS))) BLACK_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp @@ -163,8 +163,46 @@ build/black.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_COMMON_PREREQS) conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$?) 
touch $@ +.PHONY: docformatter +docformatter: build/docformatter.${CONDA_ENV_NAME}.build-stamp + +ifneq (,$(filter docformatter,$(MAKECMDGOALS))) + FORMAT_PREREQS += build/docformatter.${CONDA_ENV_NAME}.build-stamp +endif + +build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp +build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp +build/docformatter.${CONDA_ENV_NAME}.build-stamp: build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp +build/docformatter.${CONDA_ENV_NAME}.build-stamp: + touch $@ + +# black, licenseheaders, isort, and docformatter all alter files, so only run +# one at a time, by adding prerequisites, but only as necessary. +DOCFORMATTER_COMMON_PREREQS := +ifneq (,$(filter format licenseheaders,$(MAKECMDGOALS))) +DOCFORMATTER_COMMON_PREREQS += build/licenseheaders.${CONDA_ENV_NAME}.build-stamp +endif +ifneq (,$(filter format isort,$(MAKECMDGOALS))) +DOCFORMATTER_COMMON_PREREQS += build/isort.${CONDA_ENV_NAME}.build-stamp +endif +ifneq (,$(filter format black,$(MAKECMDGOALS))) +DOCFORMATTER_COMMON_PREREQS += build/black.${CONDA_ENV_NAME}.build-stamp +endif +DOCFORMATTER_COMMON_PREREQS += build/conda-env.${CONDA_ENV_NAME}.build-stamp +DOCFORMATTER_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES) + +build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES) +build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES) +build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES) + +build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS) + # Reformat python file docstrings with docformatter. + conda run -n ${CONDA_ENV_NAME} docformatter $(filter %.py,$?) + touch $@ + + .PHONY: check -check: isort-check black-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check +check: isort-check black-check docformatter-check pycodestyle pydocstyle pylint mypy # cspell markdown-link-check .PHONY: black-check black-check: build/black-check.mlos_core.${CONDA_ENV_NAME}.build-stamp @@ -186,6 +224,26 @@ build/black-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS) conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$?) touch $@ +.PHONY: docformatter-check +docformatter-check: build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp +docformatter-check: build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp +docformatter-check: build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp + +# Make sure docformatter format rules run before docformatter-check rules. +build/docformatter-check.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_FILES) +build/docformatter-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES) +build/docformatter-check.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES) + +BLACK_CHECK_COMMON_PREREQS := build/conda-env.${CONDA_ENV_NAME}.build-stamp +BLACK_CHECK_COMMON_PREREQS += $(FORMAT_PREREQS) +BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES) + +build/docformatter-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS) + # Check for import sort order. + # Note: if this fails use "make format" or "make docformatter" to fix it. + conda run -n ${CONDA_ENV_NAME} docformatter --check $(filter %.py,$?) 
+ touch $@ + .PHONY: isort-check isort-check: build/isort-check.mlos_core.${CONDA_ENV_NAME}.build-stamp isort-check: build/isort-check.mlos_bench.${CONDA_ENV_NAME}.build-stamp @@ -723,6 +781,8 @@ clean-doc: clean-format: rm -f build/black.${CONDA_ENV_NAME}.build-stamp rm -f build/black.mlos_*.${CONDA_ENV_NAME}.build-stamp + rm -f build/docformatter.${CONDA_ENV_NAME}.build-stamp + rm -f build/docformatter.mlos_*.${CONDA_ENV_NAME}.build-stamp rm -f build/isort.${CONDA_ENV_NAME}.build-stamp rm -f build/isort.mlos_*.${CONDA_ENV_NAME}.build-stamp rm -f build/licenseheaders.${CONDA_ENV_NAME}.build-stamp @@ -737,6 +797,8 @@ clean-check: rm -f build/black-check.build-stamp rm -f build/black-check.${CONDA_ENV_NAME}.build-stamp rm -f build/black-check.mlos_*.${CONDA_ENV_NAME}.build-stamp + rm -f build/docformatter-check.${CONDA_ENV_NAME}.build-stamp + rm -f build/docformatter-check.mlos_*.${CONDA_ENV_NAME}.build-stamp rm -f build/isort-check.${CONDA_ENV_NAME}.build-stamp rm -f build/isort-check.mlos_*.${CONDA_ENV_NAME}.build-stamp rm -f build/pycodestyle.build-stamp From 41015a7729bb05132ae4cc93f5630d853b04fc9b Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 17:33:12 +0000 Subject: [PATCH 22/54] fixups --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 6a1ee8ddcc..1f0ad181e3 100644 --- a/Makefile +++ b/Makefile @@ -197,7 +197,7 @@ build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FIL build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS) # Reformat python file docstrings with docformatter. - conda run -n ${CONDA_ENV_NAME} docformatter $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$?) touch $@ @@ -241,7 +241,7 @@ BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES) build/docformatter-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS) # Check for import sort order. # Note: if this fails use "make format" or "make docformatter" to fix it. - conda run -n ${CONDA_ENV_NAME} docformatter --check $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$?) touch $@ .PHONY: isort-check From d2f9fa62aa1d89fdf84f2d48f59bbc8779983018 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 17:33:19 +0000 Subject: [PATCH 23/54] tweaks --- pyproject.toml | 1 - setup.cfg | 4 +--- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0bebbed6b5..0d24377de9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,4 +12,3 @@ src_paths = ["mlos_core", "mlos_bench", "mlos_viz"] recursive = true black = true style = "numpy" -blank = true diff --git a/setup.cfg b/setup.cfg index 9d09b3356d..4ea8ea7c93 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,12 +16,10 @@ statistics = True # D102: Missing docstring in public method (Avoids inheritence bug. Force checked in .pylintrc instead.) # D105: Missing docstring in magic method # D107: Missing docstring in __init__ -# D200: One-line docstring should fit on one line with quotes # D401: First line should be in imperative mood # We have many docstrings that are too long to fit on one line, so we ignore both of these two rules: # D205: 1 blank line required between summary line and description -# D400: First line should end with a period -add_ignore = D102,D105,D107,D200,D401,D205,D400 +add_ignore = D102,D105,D107,D401,D205 match = .+(? 
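The [tool.docformatter] settings added earlier in this series (black-compatible, numpy style) affect docstring layout rather than code. A minimal before/after sketch of the kind of rewrite it performs (the function here is hypothetical, not from the repository):

# Before: a one-line summary spread across three lines.
def load_config(path):
    """
    Load a JSON config file.
    """
    return path


# After docformatter: the summary collapses onto a single line with the quotes,
# which is also the layout pydocstyle's D200 rule expects.
def load_config(path):
    """Load a JSON config file."""
    return path

The later "docformatter applied" commit in this series consists almost entirely of rewrites of this shape.
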
Date: Mon, 8 Jul 2024 17:35:57 +0000 Subject: [PATCH 24/54] fixups --- Makefile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 1f0ad181e3..4c0d544459 100644 --- a/Makefile +++ b/Makefile @@ -197,7 +197,8 @@ build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FIL build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS) # Reformat python file docstrings with docformatter. - conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$?) || true + conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$?) touch $@ From dd1d30d7e4570352f3d99ed0d5381cbe41ee0257 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 17:41:33 +0000 Subject: [PATCH 25/54] doc style tweaks --- pyproject.toml | 2 ++ setup.cfg | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 0d24377de9..16484d0aba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,3 +12,5 @@ src_paths = ["mlos_core", "mlos_bench", "mlos_viz"] recursive = true black = true style = "numpy" +pre-summary-newline = true +close-quotes-on-newline = true diff --git a/setup.cfg b/setup.cfg index 4ea8ea7c93..d0c2ee37c7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,10 +16,11 @@ statistics = True # D102: Missing docstring in public method (Avoids inheritence bug. Force checked in .pylintrc instead.) # D105: Missing docstring in magic method # D107: Missing docstring in __init__ +# D200: One-line docstring should fit on one line with quotes # D401: First line should be in imperative mood # We have many docstrings that are too long to fit on one line, so we ignore both of these two rules: # D205: 1 blank line required between summary line and description -add_ignore = D102,D105,D107,D401,D205 +add_ignore = D102,D105,D107,D200,D401,D205 match = .+(? Date: Mon, 8 Jul 2024 17:48:03 +0000 Subject: [PATCH 26/54] fixups - run on all files --- Makefile | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 4c0d544459..15eaa9b74c 100644 --- a/Makefile +++ b/Makefile @@ -126,7 +126,7 @@ build/isort.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES) build/isort.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_COMMON_PREREQS) # Reformat python file imports with isort. - conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} isort --verbose --only-modified --atomic -j0 $(filter %.py,$+) touch $@ .PHONY: black @@ -160,7 +160,7 @@ build/black.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES) build/black.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_COMMON_PREREQS) # Reformat python files with black. - conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} black $(filter %.py,$+) touch $@ .PHONY: docformatter @@ -197,8 +197,8 @@ build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FIL build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS) # Reformat python file docstrings with docformatter. - conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$?) || true - conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$?) 
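With the pre-summary-newline and close-quotes-on-newline options added above, longer docstrings are laid out with the summary on its own line after the opening quotes and with the closing quotes on a separate line, roughly as follows (the class here exists only for illustration; the docstring text is taken from the rewrites applied later in this series):

class Example:
    """A hypothetical class used only to show the multi-line docstring shape."""

    def teardown(self) -> None:
        """
        Tear down the benchmark environment.

        This method must be idempotent, i.e., calling it several times in a row
        should be equivalent to a single call.
        """
        return None
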
+ conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$+) || true + conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+) touch $@ @@ -222,7 +222,7 @@ BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES) build/black-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS) # Check for import sort order. # Note: if this fails use "make format" or "make black" to fix it. - conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} black --verbose --check --diff $(filter %.py,$+) touch $@ .PHONY: docformatter-check @@ -242,7 +242,7 @@ BLACK_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES) build/docformatter-check.%.${CONDA_ENV_NAME}.build-stamp: $(BLACK_CHECK_COMMON_PREREQS) # Check for import sort order. # Note: if this fails use "make format" or "make docformatter" to fix it. - conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} docformatter --check --diff $(filter %.py,$+) touch $@ .PHONY: isort-check @@ -261,7 +261,7 @@ ISORT_CHECK_COMMON_PREREQS += $(MLOS_GLOBAL_CONF_FILES) build/isort-check.%.${CONDA_ENV_NAME}.build-stamp: $(ISORT_CHECK_COMMON_PREREQS) # Note: if this fails use "make format" or "make isort" to fix it. - conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$?) + conda run -n ${CONDA_ENV_NAME} isort --only-modified --check --diff -j0 $(filter %.py,$+) touch $@ .PHONY: pycodestyle From bb2f4daaa6bc515c11731b0aba6bb8de8a01a633 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 17:58:43 +0000 Subject: [PATCH 27/54] tweaks --- setup.cfg | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index d0c2ee37c7..88fd64a8e2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,14 +13,14 @@ show-source = True statistics = True [pydocstyle] -# D102: Missing docstring in public method (Avoids inheritence bug. Force checked in .pylintrc instead.) +# D102: Missing docstring in public method (Avoids inheritence bug. Force checked in pylint instead.) # D105: Missing docstring in magic method # D107: Missing docstring in __init__ -# D200: One-line docstring should fit on one line with quotes # D401: First line should be in imperative mood # We have many docstrings that are too long to fit on one line, so we ignore both of these two rules: # D205: 1 blank line required between summary line and description -add_ignore = D102,D105,D107,D200,D401,D205 +# D400: First line should end with a period +add_ignore = D102,D105,D107,D401,D205,D400 match = .+(? 
Date: Mon, 8 Jul 2024 18:06:35 +0000 Subject: [PATCH 28/54] docformatter applied --- mlos_bench/mlos_bench/__init__.py | 5 +- mlos_bench/mlos_bench/config/__init__.py | 4 +- .../fio/scripts/local/process_fio_results.py | 12 +- .../scripts/local/process_redis_results.py | 8 +- .../boot/scripts/local/create_new_grub_cfg.py | 4 +- .../mlos_bench/config/schemas/__init__.py | 4 +- .../config/schemas/config_schemas.py | 20 ++-- mlos_bench/mlos_bench/dict_templater.py | 12 +- .../mlos_bench/environments/__init__.py | 5 +- .../environments/base_environment.py | 51 ++++---- .../mlos_bench/environments/composite_env.py | 26 ++--- .../mlos_bench/environments/local/__init__.py | 4 +- .../environments/local/local_env.py | 24 ++-- .../environments/local/local_fileshare_env.py | 20 ++-- .../mlos_bench/environments/mock_env.py | 9 +- .../environments/remote/__init__.py | 4 +- .../environments/remote/host_env.py | 12 +- .../environments/remote/network_env.py | 8 +- .../mlos_bench/environments/remote/os_env.py | 12 +- .../environments/remote/remote_env.py | 8 +- .../environments/remote/saas_env.py | 8 +- .../mlos_bench/environments/remote/vm_env.py | 4 +- .../mlos_bench/environments/script_env.py | 8 +- mlos_bench/mlos_bench/environments/status.py | 41 ++----- mlos_bench/mlos_bench/event_loop_context.py | 34 +++--- mlos_bench/mlos_bench/launcher.py | 50 ++++---- mlos_bench/mlos_bench/optimizers/__init__.py | 4 +- .../mlos_bench/optimizers/base_optimizer.py | 56 ++++----- .../optimizers/convert_configspace.py | 14 +-- .../optimizers/grid_search_optimizer.py | 12 +- .../optimizers/mlos_core_optimizer.py | 16 +-- .../mlos_bench/optimizers/mock_optimizer.py | 12 +- .../optimizers/one_shot_optimizer.py | 5 +- .../optimizers/track_best_optimizer.py | 12 +- mlos_bench/mlos_bench/os_environ.py | 4 +- mlos_bench/mlos_bench/schedulers/__init__.py | 4 +- .../mlos_bench/schedulers/base_scheduler.py | 59 +++++----- .../mlos_bench/schedulers/sync_scheduler.py | 16 +-- mlos_bench/mlos_bench/services/__init__.py | 4 +- .../mlos_bench/services/base_fileshare.py | 8 +- .../mlos_bench/services/base_service.py | 32 +++-- .../mlos_bench/services/config_persistence.py | 31 +++-- .../mlos_bench/services/local/__init__.py | 4 +- .../mlos_bench/services/local/local_exec.py | 22 ++-- .../services/local/temp_dir_context.py | 16 ++- .../services/remote/azure/__init__.py | 4 +- .../services/remote/azure/azure_auth.py | 16 +-- .../remote/azure/azure_deployment_services.py | 31 ++--- .../services/remote/azure/azure_fileshare.py | 19 ++- .../remote/azure/azure_network_services.py | 8 +- .../services/remote/azure/azure_saas.py | 16 +-- .../remote/azure/azure_vm_services.py | 11 +- .../services/remote/ssh/ssh_fileshare.py | 10 +- .../services/remote/ssh/ssh_host_service.py | 15 +-- .../services/remote/ssh/ssh_service.py | 26 ++--- .../mlos_bench/services/types/__init__.py | 4 +- .../services/types/authenticator_type.py | 8 +- .../services/types/config_loader_type.py | 21 ++-- .../services/types/fileshare_type.py | 8 +- .../services/types/host_ops_type.py | 8 +- .../services/types/host_provisioner_type.py | 11 +- .../services/types/local_exec_type.py | 15 ++- .../types/network_provisioner_type.py | 8 +- .../mlos_bench/services/types/os_ops_type.py | 12 +- .../services/types/remote_config_type.py | 8 +- .../services/types/remote_exec_type.py | 10 +- .../services/types/vm_provisioner_type.py | 12 +- mlos_bench/mlos_bench/storage/__init__.py | 4 +- .../storage/base_experiment_data.py | 15 +-- mlos_bench/mlos_bench/storage/base_storage.py | 63 
+++++----- .../mlos_bench/storage/base_trial_data.py | 36 ++---- .../storage/base_tunable_config_data.py | 8 +- .../base_tunable_config_trial_group_data.py | 29 ++--- mlos_bench/mlos_bench/storage/sql/__init__.py | 4 +- mlos_bench/mlos_bench/storage/sql/common.py | 14 +-- .../mlos_bench/storage/sql/experiment.py | 14 +-- .../mlos_bench/storage/sql/experiment_data.py | 7 +- mlos_bench/mlos_bench/storage/sql/schema.py | 20 +--- mlos_bench/mlos_bench/storage/sql/storage.py | 8 +- mlos_bench/mlos_bench/storage/sql/trial.py | 9 +- .../mlos_bench/storage/sql/trial_data.py | 20 +--- .../storage/sql/tunable_config_data.py | 4 +- .../sql/tunable_config_trial_group_data.py | 15 +-- .../mlos_bench/storage/storage_factory.py | 4 +- mlos_bench/mlos_bench/storage/util.py | 4 +- mlos_bench/mlos_bench/tests/__init__.py | 21 ++-- .../mlos_bench/tests/config/__init__.py | 7 +- .../cli/test_load_cli_config_examples.py | 4 +- .../mlos_bench/tests/config/conftest.py | 4 +- .../test_load_environment_config_examples.py | 4 +- .../test_load_global_config_examples.py | 4 +- .../test_load_optimizer_config_examples.py | 4 +- .../tests/config/schemas/__init__.py | 29 ++--- .../config/schemas/cli/test_cli_schemas.py | 12 +- .../environments/test_environment_schemas.py | 16 +-- .../schemas/globals/test_globals_schemas.py | 8 +- .../optimizers/test_optimizer_schemas.py | 24 ++-- .../schedulers/test_scheduler_schemas.py | 16 +-- .../schemas/services/test_services_schemas.py | 16 +-- .../schemas/storage/test_storage_schemas.py | 16 +-- .../test_tunable_params_schemas.py | 8 +- .../test_tunable_values_schemas.py | 8 +- .../test_load_service_config_examples.py | 4 +- .../test_load_storage_config_examples.py | 4 +- mlos_bench/mlos_bench/tests/conftest.py | 24 ++-- .../mlos_bench/tests/dict_templater_test.py | 20 +--- .../mlos_bench/tests/environments/__init__.py | 8 +- .../tests/environments/base_env_test.py | 24 +--- .../composite_env_service_test.py | 12 +- .../tests/environments/composite_env_test.py | 28 ++--- .../environments/include_tunables_test.py | 36 ++---- .../tests/environments/local/__init__.py | 1 + .../local/composite_local_env_test.py | 9 +- .../local/local_env_stdout_test.py | 16 +-- .../local/local_env_telemetry_test.py | 24 +--- .../environments/local/local_env_test.py | 20 +--- .../environments/local/local_env_vars_test.py | 16 +-- .../local/local_fileshare_env_test.py | 17 +-- .../tests/environments/mock_env_test.py | 20 +--- .../tests/environments/remote/__init__.py | 4 +- .../tests/environments/remote/conftest.py | 4 +- .../tests/environments/remote/test_ssh_env.py | 8 +- .../tests/event_loop_context_test.py | 5 +- .../tests/launcher_in_process_test.py | 8 +- .../tests/launcher_parse_args_test.py | 8 +- .../mlos_bench/tests/launcher_run_test.py | 32 ++--- .../mlos_bench/tests/optimizers/__init__.py | 1 + .../mlos_bench/tests/optimizers/conftest.py | 36 ++---- .../optimizers/grid_search_optimizer_test.py | 32 ++--- .../tests/optimizers/llamatune_opt_test.py | 16 +-- .../tests/optimizers/mlos_core_opt_df_test.py | 12 +- .../optimizers/mlos_core_opt_smac_test.py | 32 ++--- .../tests/optimizers/mock_opt_test.py | 32 ++--- .../optimizers/opt_bulk_register_test.py | 49 +++----- .../optimizers/toy_optimization_loop_test.py | 24 +--- .../mlos_bench/tests/services/__init__.py | 1 + .../tests/services/config_persistence_test.py | 32 ++--- .../tests/services/local/__init__.py | 1 + .../services/local/local_exec_python_test.py | 12 +- .../tests/services/local/local_exec_test.py | 52 +++------ 
.../tests/services/local/mock/__init__.py | 4 +- .../local/mock/mock_local_exec_service.py | 8 +- .../mlos_bench/tests/services/mock_service.py | 12 +- .../tests/services/remote/__init__.py | 1 + .../tests/services/remote/azure/__init__.py | 8 +- .../remote/azure/azure_fileshare_test.py | 6 +- .../azure/azure_network_services_test.py | 16 +-- .../remote/azure/azure_vm_services_test.py | 48 ++------ .../tests/services/remote/azure/conftest.py | 28 ++--- .../tests/services/remote/mock/__init__.py | 4 +- .../services/remote/mock/mock_auth_service.py | 8 +- .../remote/mock/mock_fileshare_service.py | 16 +-- .../remote/mock/mock_network_service.py | 8 +- .../remote/mock/mock_remote_exec_service.py | 8 +- .../services/remote/mock/mock_vm_service.py | 8 +- .../tests/services/remote/ssh/__init__.py | 9 +- .../tests/services/remote/ssh/conftest.py | 4 +- .../tests/services/remote/ssh/fixtures.py | 11 +- .../services/remote/ssh/test_ssh_fileshare.py | 8 +- .../remote/ssh/test_ssh_host_service.py | 16 +-- .../services/remote/ssh/test_ssh_service.py | 5 +- .../test_service_method_registering.py | 8 +- .../mlos_bench/tests/storage/__init__.py | 4 +- .../mlos_bench/tests/storage/conftest.py | 4 +- .../tests/storage/exp_context_test.py | 8 +- .../mlos_bench/tests/storage/exp_data_test.py | 24 ++-- .../mlos_bench/tests/storage/exp_load_test.py | 37 ++---- .../mlos_bench/tests/storage/sql/__init__.py | 4 +- .../mlos_bench/tests/storage/sql/fixtures.py | 42 +++---- .../tests/storage/trial_config_test.py | 17 +-- .../tests/storage/trial_data_test.py | 8 +- .../tests/storage/trial_schedule_test.py | 12 +- .../tests/storage/trial_telemetry_test.py | 16 +-- .../tests/storage/tunable_config_data_test.py | 19 +-- .../tunable_config_trial_group_data_test.py | 4 +- .../mlos_bench/tests/test_with_alt_tz.py | 8 +- .../tests/tunable_groups_fixtures.py | 8 +- .../mlos_bench/tests/tunables/__init__.py | 1 + .../mlos_bench/tests/tunables/conftest.py | 4 +- .../tunables/test_empty_tunable_group.py | 12 +- .../tunables/test_tunable_categoricals.py | 9 +- .../tunables/test_tunables_size_props.py | 14 +-- .../tests/tunables/tunable_accessors_test.py | 20 +--- .../tests/tunables/tunable_comparison_test.py | 36 ++---- .../tests/tunables/tunable_definition_test.py | 109 +++++------------- .../tunables/tunable_distributions_test.py | 24 +--- .../tunables/tunable_group_indexing_test.py | 8 +- .../tunables/tunable_group_subgroup_test.py | 8 +- .../tunables/tunable_group_update_test.py | 20 +--- .../tunables/tunable_slice_references_test.py | 20 +--- .../tunable_to_configspace_distr_test.py | 9 +- .../tunables/tunable_to_configspace_test.py | 34 ++---- .../tests/tunables/tunables_assign_test.py | 89 +++++--------- .../tests/tunables/tunables_copy_test.py | 16 +-- .../tests/tunables/tunables_str_test.py | 12 +- mlos_bench/mlos_bench/tests/util_git_test.py | 8 +- .../mlos_bench/tests/util_nullable_test.py | 16 +-- .../mlos_bench/tests/util_try_parse_test.py | 8 +- mlos_bench/mlos_bench/tunables/__init__.py | 4 +- .../mlos_bench/tunables/covariant_group.py | 43 ++++--- mlos_bench/mlos_bench/tunables/tunable.py | 98 ++++++---------- .../mlos_bench/tunables/tunable_groups.py | 38 +++--- mlos_bench/mlos_bench/util.py | 21 ++-- mlos_bench/mlos_bench/version.py | 4 +- mlos_bench/setup.py | 4 +- mlos_core/mlos_core/__init__.py | 4 +- mlos_core/mlos_core/optimizers/__init__.py | 16 ++- .../bayesian_optimizers/__init__.py | 4 +- .../bayesian_optimizers/bayesian_optimizer.py | 12 +- .../bayesian_optimizers/smac_optimizer.py | 20 ++-- 
.../mlos_core/optimizers/flaml_optimizer.py | 21 ++-- mlos_core/mlos_core/optimizers/optimizer.py | 50 ++++---- .../mlos_core/optimizers/random_optimizer.py | 15 +-- mlos_core/mlos_core/spaces/__init__.py | 4 +- .../mlos_core/spaces/adapters/__init__.py | 10 +- .../mlos_core/spaces/adapters/adapter.py | 29 +++-- .../spaces/adapters/identity_adapter.py | 7 +- .../mlos_core/spaces/adapters/llamatune.py | 35 +++--- .../mlos_core/spaces/converters/__init__.py | 4 +- .../mlos_core/spaces/converters/flaml.py | 7 +- mlos_core/mlos_core/tests/__init__.py | 10 +- .../optimizers/bayesian_optimizers_test.py | 8 +- .../mlos_core/tests/optimizers/conftest.py | 8 +- .../tests/optimizers/one_hot_test.py | 44 +++---- .../optimizers/optimizer_multiobj_test.py | 12 +- .../tests/optimizers/optimizer_test.py | 32 ++--- .../tests/optimizers/random_optimizer_test.py | 4 +- mlos_core/mlos_core/tests/spaces/__init__.py | 4 +- .../spaces/adapters/identity_adapter_test.py | 8 +- .../tests/spaces/adapters/llamatune_test.py | 36 ++---- .../adapters/space_adapter_factory_test.py | 12 +- .../mlos_core/tests/spaces/spaces_test.py | 20 +--- mlos_core/mlos_core/util.py | 7 +- mlos_core/mlos_core/version.py | 4 +- mlos_core/setup.py | 4 +- mlos_viz/mlos_viz/__init__.py | 7 +- mlos_viz/mlos_viz/base.py | 11 +- mlos_viz/mlos_viz/dabl.py | 8 +- mlos_viz/mlos_viz/tests/__init__.py | 4 +- mlos_viz/mlos_viz/tests/conftest.py | 4 +- mlos_viz/mlos_viz/tests/test_base_plot.py | 4 +- mlos_viz/mlos_viz/tests/test_dabl_plot.py | 4 +- mlos_viz/mlos_viz/tests/test_mlos_viz.py | 4 +- mlos_viz/mlos_viz/util.py | 4 +- mlos_viz/mlos_viz/version.py | 4 +- mlos_viz/setup.py | 4 +- 246 files changed, 1303 insertions(+), 2526 deletions(-) diff --git a/mlos_bench/mlos_bench/__init__.py b/mlos_bench/mlos_bench/__init__.py index 1fed310b78..db8c235041 100644 --- a/mlos_bench/mlos_bench/__init__.py +++ b/mlos_bench/mlos_bench/__init__.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -mlos_bench is a framework to help automate benchmarking and and -OS/application parameter autotuning. +"""mlos_bench is a framework to help automate benchmarking and and OS/application +parameter autotuning. """ diff --git a/mlos_bench/mlos_bench/config/__init__.py b/mlos_bench/mlos_bench/config/__init__.py index 590e3d50d0..b78386118c 100644 --- a/mlos_bench/mlos_bench/config/__init__.py +++ b/mlos_bench/mlos_bench/config/__init__.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -mlos_bench.config -""" +"""mlos_bench.config.""" diff --git a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py index c32dea9bf6..a6d2d31df6 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py @@ -3,9 +3,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Script for post-processing FIO results for mlos_bench. -""" +"""Script for post-processing FIO results for mlos_bench.""" import argparse import itertools @@ -16,9 +14,7 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: - """ - Flatten every dict in the hierarchy and rename the keys with the dict path. 
- """ + """Flatten every dict in the hierarchy and rename the keys with the dict path.""" if isinstance(data, dict): for (key, val) in data.items(): yield from _flat_dict(val, f"{path}.{key}") @@ -27,9 +23,7 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: def _main(input_file: str, output_file: str, prefix: str) -> None: - """ - Convert FIO read data from JSON to tall CSV. - """ + """Convert FIO read data from JSON to tall CSV.""" with open(input_file, mode='r', encoding='utf-8') as fh_input: json_data = json.load(fh_input) diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py index e33c717953..08cfe57faa 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py @@ -3,9 +3,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Script for post-processing redis-benchmark results. -""" +"""Script for post-processing redis-benchmark results.""" import argparse @@ -13,9 +11,7 @@ def _main(input_file: str, output_file: str) -> None: - """ - Re-shape Redis benchmark CSV results from wide to long. - """ + """Re-shape Redis benchmark CSV results from wide to long.""" df_wide = pd.read_csv(input_file) # Format the results from wide to long diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py index 41bd162459..47ed159c5a 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py @@ -6,8 +6,8 @@ """ Python script to parse through JSON and create new config file. -This script will be run in the SCHEDULER. -NEW_CFG will need to be copied over to the VM (/etc/default/grub.d). +This script will be run in the SCHEDULER. NEW_CFG will need to be copied over to the VM +(/etc/default/grub.d). """ import json diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py index fa3b63e2e6..05756f59bf 100644 --- a/mlos_bench/mlos_bench/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/config/schemas/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A module for managing config schemas and their validation. -""" +"""A module for managing config schemas and their validation.""" from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index 82cbcacce2..bfba5ed8a6 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A simple class for describing where to find different config schemas and validating configs against them. +"""A simple class for describing where to find different config schemas and validating +configs against them. 
""" import json # schema files are pure json - no comments @@ -35,8 +35,8 @@ # Note: we separate out the SchemaStore from a class method on ConfigSchema # because of issues with mypy/pylint and non-Enum-member class members. class SchemaStore(Mapping): - """ - A simple class for storing schemas and subschemas for the validator to reference. + """A simple class for storing schemas and subschemas for the validator to + reference. """ # A class member mapping of schema id to schema object. @@ -57,7 +57,9 @@ def __getitem__(self, key: str) -> dict: @classmethod def _load_schemas(cls) -> None: - """Loads all schemas and subschemas into the schema store for the validator to reference.""" + """Loads all schemas and subschemas into the schema store for the validator to + reference. + """ if cls._SCHEMA_STORE: return for root, _, files in walk(CONFIG_SCHEMA_DIR): @@ -77,7 +79,9 @@ def _load_schemas(cls) -> None: @classmethod def _load_registry(cls) -> None: - """Also store them in a Registry object for referencing by recent versions of jsonschema.""" + """Also store them in a Registry object for referencing by recent versions of + jsonschema. + """ if not cls._SCHEMA_STORE: cls._load_schemas() cls._REGISTRY = Registry().with_resources([ @@ -97,9 +101,7 @@ def registry(self) -> Registry: class ConfigSchema(Enum): - """ - An enum to help describe schema types and help validate configs against them. - """ + """An enum to help describe schema types and help validate configs against them.""" CLI = path_join(CONFIG_SCHEMA_DIR, "cli/cli-schema.json") GLOBALS = path_join(CONFIG_SCHEMA_DIR, "cli/globals-schema.json") diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index 4ccef7817b..2243bec7a4 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Simple class to help with nested dictionary $var templating. -""" +"""Simple class to help with nested dictionary $var templating.""" from copy import deepcopy from string import Template @@ -14,9 +12,7 @@ class DictTemplater: # pylint: disable=too-few-public-methods - """ - Simple class to help with nested dictionary $var templating. - """ + """Simple class to help with nested dictionary $var templating.""" def __init__(self, source_dict: Dict[str, Any]): """ @@ -56,9 +52,7 @@ def expand_vars(self, *, return self._dict def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any: - """ - Recursively expand $var strings in the currently operating dictionary. - """ + """Recursively expand $var strings in the currently operating dictionary.""" if isinstance(value, str): # First try to expand all $vars internally. value = Template(value).safe_substitute(self._dict) diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py index a1ccadae5f..8a4df5a5b2 100644 --- a/mlos_bench/mlos_bench/environments/__init__.py +++ b/mlos_bench/mlos_bench/environments/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tunable Environments for mlos_bench. 
-""" +"""Tunable Environments for mlos_bench.""" from mlos_bench.environments.base_environment import Environment from mlos_bench.environments.composite_env import CompositeEnv @@ -16,7 +14,6 @@ __all__ = [ 'Status', - 'Environment', 'MockEnv', 'RemoteEnv', diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index 61fbd69f50..d91bb57041 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A hierarchy of benchmark environments. -""" +"""A hierarchy of benchmark environments.""" import abc import json @@ -43,9 +41,7 @@ class Environment(metaclass=abc.ABCMeta): # pylint: disable=too-many-instance-attributes - """ - An abstract base of all benchmark environments. - """ + """An abstract base of all benchmark environments.""" @classmethod def new(cls, @@ -164,10 +160,9 @@ def __init__(self, name, json.dumps(self.config, indent=2)) def _validate_json_config(self, config: dict, name: str) -> None: - """ - Reconstructs a basic json config that this class might have been - instantiated from in order to validate configs provided outside the - file loading mechanism. + """Reconstructs a basic json config that this class might have been instantiated + from in order to validate configs provided outside the file loading + mechanism. """ json_config: dict = { "class": self.__class__.__module__ + "." + self.__class__.__name__, @@ -211,9 +206,7 @@ def _expand_groups(groups: Iterable[str], @staticmethod def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict: - """ - Expand `$var` into actual values of the variables. - """ + """Expand `$var` into actual values of the variables.""" return DictTemplater(params).expand_vars(extra_source_dict=global_config) @property @@ -222,9 +215,7 @@ def _config_loader_service(self) -> "SupportsConfigLoading": return self._service.config_loader_service def __enter__(self) -> 'Environment': - """ - Enter the environment's benchmarking context. - """ + """Enter the environment's benchmarking context.""" _LOG.debug("Environment START :: %s", self) assert not self._in_context if self._service: @@ -235,9 +226,7 @@ def __enter__(self) -> 'Environment': def __exit__(self, ex_type: Optional[Type[BaseException]], ex_val: Optional[BaseException], ex_tb: Optional[TracebackType]) -> Literal[False]: - """ - Exit the context of the benchmarking environment. - """ + """Exit the context of the benchmarking environment.""" ex_throw = None if ex_val is None: _LOG.debug("Environment END :: %s", self) @@ -267,8 +256,8 @@ def __repr__(self) -> str: def pprint(self, indent: int = 4, level: int = 0) -> str: """ - Pretty-print the environment configuration. - For composite environments, print all children environments as well. + Pretty-print the environment configuration. For composite environments, print + all children environments as well. Parameters ---------- @@ -288,8 +277,8 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ Plug tunable values into the base config. If the tunable group is unknown, - ignore it (it might belong to another environment). This method should - never mutate the original config or the tunables. + ignore it (it might belong to another environment). 
This method should never + mutate the original config or the tunables. Parameters ---------- @@ -321,8 +310,9 @@ def tunable_params(self) -> TunableGroups: @property def parameters(self) -> Dict[str, TunableValue]: """ - Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`). - Note that before `.setup()` is called, all tunables will be set to None. + Key/value pairs of all environment parameters (i.e., `const_args` and + `tunable_params`). Note that before `.setup()` is called, all tunables will be + set to None. Returns ------- @@ -334,8 +324,8 @@ def parameters(self) -> Dict[str, TunableValue]: def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ Set up a new benchmark environment, if necessary. This method must be - idempotent, i.e., calling it several times in a row should be - equivalent to a single call. + idempotent, i.e., calling it several times in a row should be equivalent to a + single call. Parameters ---------- @@ -382,9 +372,10 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - def teardown(self) -> None: """ - Tear down the benchmark environment. This method must be idempotent, - i.e., calling it several times in a row should be equivalent to a - single call. + Tear down the benchmark environment. + + This method must be idempotent, i.e., calling it several times in a row should + be equivalent to a single call. """ _LOG.info("Teardown %s", self) # Make sure we create a context before invoking setup/run/status/teardown diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index a71b8ab9be..4bf38a5ef2 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Composite benchmark environment. -""" +"""Composite benchmark environment.""" import logging from datetime import datetime @@ -23,9 +21,7 @@ class CompositeEnv(Environment): - """ - Composite benchmark environment. - """ + """Composite benchmark environment.""" def __init__(self, *, @@ -111,9 +107,7 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], @property def children(self) -> List[Environment]: - """ - Return the list of child environments. - """ + """Return the list of child environments.""" return self._children def pprint(self, indent: int = 4, level: int = 0) -> str: @@ -138,6 +132,7 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: def _add_child(self, env: Environment, tunables: TunableGroups) -> None: """ Add a new child environment to the composite environment. + This method is called from the constructor only. """ _LOG.debug("Merge tunables: '%s' <- '%s' :: %s", self, env, env.tunable_params) @@ -170,9 +165,10 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - def teardown(self) -> None: """ - Tear down the children environments. This method is idempotent, - i.e., calling it several times is equivalent to a single call. - The environments are being torn down in the reverse order. + Tear down the children environments. + + This method is idempotent, i.e., calling it several times is equivalent to a + single call. The environments are being torn down in the reverse order. 
""" assert self._in_context for env_context in reversed(self._child_contexts): @@ -181,9 +177,9 @@ def teardown(self) -> None: def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: """ - Submit a new experiment to the environment. - Return the result of the *last* child environment if successful, - or the status of the last failed environment otherwise. + Submit a new experiment to the environment. Return the result of the *last* + child environment if successful, or the status of the last failed environment + otherwise. Returns ------- diff --git a/mlos_bench/mlos_bench/environments/local/__init__.py b/mlos_bench/mlos_bench/environments/local/__init__.py index 0cdd8349b4..9a51941529 100644 --- a/mlos_bench/mlos_bench/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/environments/local/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Local Environments for mlos_bench. -""" +"""Local Environments for mlos_bench.""" from mlos_bench.environments.local.local_env import LocalEnv from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index da20f5c961..8cb877a9d0 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Scheduler-side benchmark environment to run scripts locally. -""" +"""Scheduler-side benchmark environment to run scripts locally.""" import json import logging @@ -32,9 +30,7 @@ class LocalEnv(ScriptEnv): # pylint: disable=too-many-instance-attributes - """ - Scheduler-side Environment that runs scripts locally. - """ + """Scheduler-side Environment that runs scripts locally.""" def __init__(self, *, @@ -90,9 +86,7 @@ def __enter__(self) -> Environment: def __exit__(self, ex_type: Optional[Type[BaseException]], ex_val: Optional[BaseException], ex_tb: Optional[TracebackType]) -> Literal[False]: - """ - Exit the context of the benchmarking environment. - """ + """Exit the context of the benchmarking environment.""" assert not (self._temp_dir is None or self._temp_dir_context is None) self._temp_dir_context.__exit__(ex_type, ex_val, ex_tb) self._temp_dir = None @@ -101,8 +95,8 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ - Check if the environment is ready and set up the application - and benchmarks, if necessary. + Check if the environment is ready and set up the application and benchmarks, if + necessary. Parameters ---------- @@ -203,9 +197,7 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: @staticmethod def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame: - """ - Strip trailing spaces from column names (Windows only). - """ + """Strip trailing spaces from column names (Windows only).""" # Windows cmd interpretation of > redirect symbols can leave trailing spaces in # the final column, which leads to misnamed columns. # For now, we simply strip trailing spaces from column names to account for that. @@ -254,9 +246,7 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: ]) def teardown(self) -> None: - """ - Clean up the local environment. 
- """ + """Clean up the local environment.""" if self._script_teardown: _LOG.info("Local teardown: %s", self) (return_code, _output) = self._local_exec(self._script_teardown) diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index 174afd387c..fd6c2c1127 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -2,9 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Scheduler-side Environment to run scripts locally -and upload/download data to the shared storage. +"""Scheduler-side Environment to run scripts locally and upload/download data to the +shared storage. """ import logging @@ -24,9 +23,8 @@ class LocalFileShareEnv(LocalEnv): - """ - Scheduler-side Environment that runs scripts locally - and uploads/downloads data to the shared file storage. + """Scheduler-side Environment that runs scripts locally and uploads/downloads data + to the shared file storage. """ def __init__(self, @@ -73,9 +71,8 @@ def __init__(self, self._download = self._template_from_to("download") def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]: - """ - Convert a list of {"from": "...", "to": "..."} to a list of pairs - of string.Template objects so that we can plug in self._params into it later. + """Convert a list of {"from": "...", "to": "..."} to a list of pairs of + string.Template objects so that we can plug in self._params into it later. """ return [ (Template(d['from']), Template(d['to'])) @@ -87,6 +84,7 @@ def _expand(from_to: Iterable[Tuple[Template, Template]], params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. + Return a generator of (str, str) pairs of paths. """ return ( @@ -152,8 +150,8 @@ def _download_files(self, ignore_missing: bool = False) -> None: def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: """ - Download benchmark results from the shared storage - and run post-processing scripts locally. + Download benchmark results from the shared storage and run post-processing + scripts locally. Returns ------- diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index cc47b95500..16ff1195de 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Scheduler-side environment to mock the benchmark results. -""" +"""Scheduler-side environment to mock the benchmark results.""" import logging import random @@ -22,9 +20,7 @@ class MockEnv(Environment): - """ - Scheduler-side environment to mock the benchmark results. - """ + """Scheduler-side environment to mock the benchmark results.""" _NOISE_VAR = 0.2 """Variance of the Gaussian noise added to the benchmark value.""" @@ -97,6 +93,7 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: def _normalized(tunable: Tunable) -> float: """ Get the NORMALIZED value of a tunable. + That is, map current value to the [0, 1] range. 
""" val = None diff --git a/mlos_bench/mlos_bench/environments/remote/__init__.py b/mlos_bench/mlos_bench/environments/remote/__init__.py index f07575ac86..3b26f8d6a7 100644 --- a/mlos_bench/mlos_bench/environments/remote/__init__.py +++ b/mlos_bench/mlos_bench/environments/remote/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Remote Tunable Environments for mlos_bench. -""" +"""Remote Tunable Environments for mlos_bench.""" from mlos_bench.environments.remote.host_env import HostEnv from mlos_bench.environments.remote.network_env import NetworkEnv diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py index 05896c9e60..ae88fa2197 100644 --- a/mlos_bench/mlos_bench/environments/remote/host_env.py +++ b/mlos_bench/mlos_bench/environments/remote/host_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Remote host Environment. -""" +"""Remote host Environment.""" import logging from typing import Optional @@ -18,9 +16,7 @@ class HostEnv(Environment): - """ - Remote host environment. - """ + """Remote host environment.""" def __init__(self, *, @@ -87,9 +83,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return self._is_ready def teardown(self) -> None: - """ - Shut down the Host and release it. - """ + """Shut down the Host and release it.""" _LOG.info("Host tear down: %s", self) (status, params) = self._host_service.deprovision_host(self._params) if status.is_pending(): diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index 552f1729d9..c3ad8ccd82 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Network Environment. -""" +"""Network Environment.""" import logging from typing import Optional @@ -97,9 +95,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return self._is_ready def teardown(self) -> None: - """ - Shut down the Network and releases it. - """ + """Shut down the Network and releases it.""" if not self._deprovision_on_teardown: _LOG.info("Skipping Network deprovision: %s", self) return diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py index ef733c77c2..68a6f5fbe7 100644 --- a/mlos_bench/mlos_bench/environments/remote/os_env.py +++ b/mlos_bench/mlos_bench/environments/remote/os_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -OS-level remote Environment on Azure. -""" +"""OS-level remote Environment on Azure.""" import logging from typing import Optional @@ -20,9 +18,7 @@ class OSEnv(Environment): - """ - OS Level Environment for a host. - """ + """OS Level Environment for a host.""" def __init__(self, *, @@ -97,9 +93,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return self._is_ready def teardown(self) -> None: - """ - Clean up and shut down the host without deprovisioning it. 
- """ + """Clean up and shut down the host without deprovisioning it.""" _LOG.info("OS tear down: %s", self) (status, params) = self._os_service.shutdown(self._params) if status.is_pending(): diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index cf38a57b01..94e789b198 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -77,8 +77,8 @@ def __init__(self, def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: """ - Check if the environment is ready and set up the application - and benchmarks on a remote host. + Check if the environment is ready and set up the application and benchmarks on a + remote host. Parameters ---------- @@ -143,9 +143,7 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: return (status, timestamp, output) def teardown(self) -> None: - """ - Clean up and shut down the remote environment. - """ + """Clean up and shut down the remote environment.""" if self._script_teardown: _LOG.info("Remote teardown: %s", self) (status, _timestamp, _output) = self._remote_exec(self._script_teardown) diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index b661bfad7e..024430e22a 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Cloud-based (configurable) SaaS environment. -""" +"""Cloud-based (configurable) SaaS environment.""" import logging from typing import Optional @@ -19,9 +17,7 @@ class SaaSEnv(Environment): - """ - Cloud-based (configurable) SaaS environment. - """ + """Cloud-based (configurable) SaaS environment.""" def __init__(self, *, diff --git a/mlos_bench/mlos_bench/environments/remote/vm_env.py b/mlos_bench/mlos_bench/environments/remote/vm_env.py index eae7bf982c..3be95ce2c2 100644 --- a/mlos_bench/mlos_bench/environments/remote/vm_env.py +++ b/mlos_bench/mlos_bench/environments/remote/vm_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -"Remote" VM (Host) Environment. -""" +"""Remote VM (Host) Environment.""" import logging diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index 129ac21a0f..745430ca69 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base scriptable benchmark environment. -""" +"""Base scriptable benchmark environment.""" import abc import logging @@ -21,9 +19,7 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta): - """ - Base Environment that runs scripts for setup/run/teardown. - """ + """Base Environment that runs scripts for setup/run/teardown.""" _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]") diff --git a/mlos_bench/mlos_bench/environments/status.py b/mlos_bench/mlos_bench/environments/status.py index fbe3dcccf4..f3e0d0ea37 100644 --- a/mlos_bench/mlos_bench/environments/status.py +++ b/mlos_bench/mlos_bench/environments/status.py @@ -2,17 +2,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Enum for the status of the benchmark/environment. 
-""" +"""Enum for the status of the benchmark/environment.""" import enum class Status(enum.Enum): - """ - Enum for the status of the benchmark/environment. - """ + """Enum for the status of the benchmark/environment.""" UNKNOWN = 0 PENDING = 1 @@ -24,9 +20,7 @@ class Status(enum.Enum): TIMED_OUT = 7 def is_good(self) -> bool: - """ - Check if the status of the benchmark/environment is good. - """ + """Check if the status of the benchmark/environment is good.""" return self in { Status.PENDING, Status.READY, @@ -35,9 +29,8 @@ def is_good(self) -> bool: } def is_completed(self) -> bool: - """ - Check if the status of the benchmark/environment is - one of {SUCCEEDED, CANCELED, FAILED, TIMED_OUT}. + """Check if the status of the benchmark/environment is one of {SUCCEEDED, + CANCELED, FAILED, TIMED_OUT}. """ return self in { Status.SUCCEEDED, @@ -47,37 +40,25 @@ def is_completed(self) -> bool: } def is_pending(self) -> bool: - """ - Check if the status of the benchmark/environment is PENDING. - """ + """Check if the status of the benchmark/environment is PENDING.""" return self == Status.PENDING def is_ready(self) -> bool: - """ - Check if the status of the benchmark/environment is READY. - """ + """Check if the status of the benchmark/environment is READY.""" return self == Status.READY def is_succeeded(self) -> bool: - """ - Check if the status of the benchmark/environment is SUCCEEDED. - """ + """Check if the status of the benchmark/environment is SUCCEEDED.""" return self == Status.SUCCEEDED def is_failed(self) -> bool: - """ - Check if the status of the benchmark/environment is FAILED. - """ + """Check if the status of the benchmark/environment is FAILED.""" return self == Status.FAILED def is_canceled(self) -> bool: - """ - Check if the status of the benchmark/environment is CANCELED. - """ + """Check if the status of the benchmark/environment is CANCELED.""" return self == Status.CANCELED def is_timed_out(self) -> bool: - """ - Check if the status of the benchmark/environment is TIMED_OUT. - """ + """Check if the status of the benchmark/environment is TIMED_OUT.""" return self == Status.FAILED diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index 4555ab7f50..8684844063 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -EventLoopContext class definition. -""" +"""EventLoopContext class definition.""" import asyncio import logging @@ -31,15 +29,15 @@ class EventLoopContext: """ - EventLoopContext encapsulates a background thread for asyncio event - loop processing as an aid for context managers. + EventLoopContext encapsulates a background thread for asyncio event loop processing + as an aid for context managers. - There is generally only expected to be one of these, either as a base - class instance if it's specific to that functionality or for the full - mlos_bench process to support parallel trial runners, for instance. + There is generally only expected to be one of these, either as a base class instance + if it's specific to that functionality or for the full mlos_bench process to support + parallel trial runners, for instance. - It's enter() and exit() routines are expected to be called from the - caller's context manager routines (e.g., __enter__ and __exit__). 
+ It's enter() and exit() routines are expected to be called from the caller's context + manager routines (e.g., __enter__ and __exit__). """ def __init__(self) -> None: @@ -49,17 +47,13 @@ def __init__(self) -> None: self._event_loop_thread_refcnt: int = 0 def _run_event_loop(self) -> None: - """ - Runs the asyncio event loop in a background thread. - """ + """Runs the asyncio event loop in a background thread.""" assert self._event_loop is not None asyncio.set_event_loop(self._event_loop) self._event_loop.run_forever() def enter(self) -> None: - """ - Manages starting the background thread for event loop processing. - """ + """Manages starting the background thread for event loop processing.""" # Start the background thread if it's not already running. with self._event_loop_thread_lock: if not self._event_loop_thread: @@ -74,9 +68,7 @@ def enter(self) -> None: self._event_loop_thread_refcnt += 1 def exit(self) -> None: - """ - Manages cleaning up the background thread for event loop processing. - """ + """Manages cleaning up the background thread for event loop processing.""" with self._event_loop_thread_lock: self._event_loop_thread_refcnt -= 1 assert self._event_loop_thread_refcnt >= 0 @@ -92,8 +84,8 @@ def exit(self) -> None: def run_coroutine(self, coro: Coroutine[Any, Any, CoroReturnType]) -> FutureReturnType: """ - Runs the given coroutine in the background event loop thread and - returns a Future that can be used to wait for the result. + Runs the given coroutine in the background event loop thread and returns a + Future that can be used to wait for the result. Parameters ---------- diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index c8e48dab69..1a0caa6bba 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -3,8 +3,8 @@ # Licensed under the MIT License. # """ -A helper class to load the configuration files, parse the command line parameters, -and instantiate the main components of mlos_bench system. +A helper class to load the configuration files, parse the command line parameters, and +instantiate the main components of mlos_bench system. It is used in `mlos_bench.run` module to run the benchmark/optimizer from the command line. @@ -40,9 +40,7 @@ class Launcher: # pylint: disable=too-few-public-methods,too-many-instance-attributes - """ - Command line launcher for mlos_bench and mlos_core. - """ + """Command line launcher for mlos_bench and mlos_core.""" def __init__(self, description: str, long_text: str = "", argv: Optional[List[str]] = None): # pylint: disable=too-many-statements @@ -143,23 +141,17 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st @property def config_loader(self) -> ConfigPersistenceService: - """ - Get the config loader service. - """ + """Get the config loader service.""" return self._config_loader @property def service(self) -> Service: - """ - Get the parent service. - """ + """Get the parent service.""" return self._parent_service @staticmethod def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]: - """ - Parse the command line arguments. - """ + """Parse the command line arguments.""" parser.add_argument( '--config', required=False, help='Main JSON5 configuration file. 
Its keys are the same as the' + @@ -259,9 +251,7 @@ def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> T @staticmethod def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]: - """ - Helper function to parse global key/value pairs from the command line. - """ + """Helper function to parse global key/value pairs from the command line.""" _LOG.debug("Extra args: %s", cmdline) config: Dict[str, TunableValue] = {} @@ -293,9 +283,8 @@ def _load_config(self, config_path: Iterable[str], args_rest: Iterable[str], global_config: Dict[str, Any]) -> Dict[str, Any]: - """ - Get key/value pairs of the global configuration parameters - from the specified config files (if any) and command line arguments. + """Get key/value pairs of the global configuration parameters from the specified + config files (if any) and command line arguments. """ for config_file in (args_globals or []): conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS) @@ -308,9 +297,8 @@ def _load_config(self, def _init_tunable_values(self, random_init: bool, seed: Optional[int], args_tunables: Optional[str]) -> TunableGroups: - """ - Initialize the tunables and load key/value pairs of the tunable values - from given JSON files, if specified. + """Initialize the tunables and load key/value pairs of the tunable values from + given JSON files, if specified. """ tunables = self.environment.tunable_params _LOG.debug("Init tunables: default = %s", tunables) @@ -332,9 +320,11 @@ def _init_tunable_values(self, random_init: bool, seed: Optional[int], def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer: """ - Instantiate the Optimizer object from JSON config file, if specified - in the --optimizer command line option. If config file not specified, - create a one-shot optimizer to run a single benchmark trial. + Instantiate the Optimizer object from JSON config file, if specified in the + --optimizer command line option. + + If config file not specified, create a one-shot optimizer to run a single + benchmark trial. """ if args_optimizer is None: # global_config may contain additional properties, so we need to @@ -352,9 +342,10 @@ def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer: def _load_storage(self, args_storage: Optional[str]) -> Storage: """ - Instantiate the Storage object from JSON file provided in the --storage - command line parameter. If omitted, create an ephemeral in-memory SQL - storage instead. + Instantiate the Storage object from JSON file provided in the --storage command + line parameter. + + If omitted, create an ephemeral in-memory SQL storage instead. """ if args_storage is None: # pylint: disable=import-outside-toplevel @@ -376,6 +367,7 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: """ Instantiate the Scheduler object from JSON file provided in the --scheduler command line parameter. + Create a simple synchronous single-threaded scheduler if omitted. """ # Set `teardown` for scheduler only to prevent conflicts with other configs. diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py index f10fa3c82e..167fe022e6 100644 --- a/mlos_bench/mlos_bench/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/optimizers/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Interfaces and wrapper classes for optimizers to be used in Autotune. 
-""" +"""Interfaces and wrapper classes for optimizers to be used in Autotune.""" from mlos_bench.optimizers.base_optimizer import Optimizer from mlos_bench.optimizers.mlos_core_optimizer import MlosCoreOptimizer diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index b9df1db1b7..e9b4ff8388 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -2,9 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base class for an interface between the benchmarking framework -and mlos_core optimizers. +"""Base class for an interface between the benchmarking framework and mlos_core +optimizers. """ import logging @@ -27,8 +26,8 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes - """ - An abstract interface between the benchmarking framework and mlos_core optimizers. + """An abstract interface between the benchmarking framework and mlos_core + optimizers. """ # See Also: mlos_bench/mlos_bench/config/schemas/optimizers/optimizer-schema.json @@ -45,7 +44,8 @@ def __init__(self, global_config: Optional[dict] = None, service: Optional[Service] = None): """ - Create a new optimizer for the given configuration space defined by the tunables. + Create a new optimizer for the given configuration space defined by the + tunables. Parameters ---------- @@ -88,10 +88,9 @@ def __init__(self, raise ValueError(f"Invalid optimization direction: {opt_dir} for {opt_target}") def _validate_json_config(self, config: dict) -> None: - """ - Reconstructs a basic json config that this class might have been - instantiated from in order to validate configs provided outside the - file loading mechanism. + """Reconstructs a basic json config that this class might have been instantiated + from in order to validate configs provided outside the file loading + mechanism. """ json_config: dict = { "class": self.__class__.__module__ + "." + self.__class__.__name__, @@ -108,9 +107,7 @@ def __repr__(self) -> str: return f"{self.name}({opt_targets},config={self._config})" def __enter__(self) -> 'Optimizer': - """ - Enter the optimizer's context. - """ + """Enter the optimizer's context.""" _LOG.debug("Optimizer START :: %s", self) assert not self._in_context self._in_context = True @@ -119,9 +116,7 @@ def __enter__(self) -> 'Optimizer': def __exit__(self, ex_type: Optional[Type[BaseException]], ex_val: Optional[BaseException], ex_tb: Optional[TracebackType]) -> Literal[False]: - """ - Exit the context of the optimizer. - """ + """Exit the context of the optimizer.""" if ex_val is None: _LOG.debug("Optimizer END :: %s", self) else: @@ -153,15 +148,14 @@ def max_iterations(self) -> int: @property def seed(self) -> int: - """ - The random seed for the optimizer. - """ + """The random seed for the optimizer.""" return self._seed @property def start_with_defaults(self) -> bool: """ Return True if the optimizer should start with the default values. + Note: This parameter is mutable and will be reset to False after the defaults are first suggested. """ @@ -197,16 +191,16 @@ def config_space(self) -> ConfigurationSpace: @property def name(self) -> str: """ - The name of the optimizer. We save this information in - mlos_bench storage to track the source of each configuration. + The name of the optimizer. + + We save this information in mlos_bench storage to track the source of each + configuration. 
""" return self.__class__.__name__ @property def targets(self) -> Dict[str, Literal['min', 'max']]: - """ - A dictionary of {target: direction} of optimization targets. - """ + """A dictionary of {target: direction} of optimization targets.""" return { opt_target: "min" if opt_dir == 1 else "max" for (opt_target, opt_dir) in self._opt_targets.items() @@ -214,8 +208,8 @@ def targets(self) -> Dict[str, Literal['min', 'max']]: @property def supports_preload(self) -> bool: - """ - Return True if the optimizer supports pre-loading the data from previous experiments. + """Return True if the optimizer supports pre-loading the data from previous + experiments. """ return True @@ -255,9 +249,8 @@ def bulk_register(self, def suggest(self) -> TunableGroups: """ - Generate the next suggestion. - Base class' implementation increments the iteration count - and returns the current values of the tunables. + Generate the next suggestion. Base class' implementation increments the + iteration count and returns the current values of the tunables. Returns ------- @@ -303,8 +296,8 @@ def _get_scores(self, status: Status, scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] ) -> Optional[Dict[str, float]]: """ - Extract a scalar benchmark score from the dataframe. - Change the sign if we are maximizing. + Extract a scalar benchmark score from the dataframe. Change the sign if we are + maximizing. Parameters ---------- @@ -340,6 +333,7 @@ def _get_scores(self, status: Status, def not_converged(self) -> bool: """ Return True if not converged, False otherwise. + Base implementation just checks the iteration count. """ return self._iter < self._max_iter diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index 62341c613d..3ab1c43ab9 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Functions to convert TunableGroups to ConfigSpace for use with the mlos_core optimizers. +"""Functions to convert TunableGroups to ConfigSpace for use with the mlos_core +optimizers. """ import logging @@ -31,6 +31,7 @@ class TunableValueKind: """ Enum for the kind of the tunable value (special or not). + It is not a true enum because ConfigSpace wants string values. """ @@ -40,9 +41,7 @@ class TunableValueKind: def _normalize_weights(weights: List[float]) -> List[float]: - """ - Helper function for normalizing weights to probabilities. - """ + """Helper function for normalizing weights to probabilities.""" total = sum(weights) return [w / total for w in weights] @@ -219,6 +218,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration: def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]: """ Remove the fields that correspond to special values in ConfigSpace. + In particular, remove and keys suffixes added by `special_param_names`. """ data = data.copy() @@ -240,8 +240,8 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]: def special_param_names(name: str) -> Tuple[str, str]: """ - Generate the names of the auxiliary hyperparameters that correspond - to a tunable that can have special values. + Generate the names of the auxiliary hyperparameters that correspond to a tunable + that can have special values. NOTE: `!` characters are currently disallowed in Tunable names in order handle this logic. 
diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 4f207f5fc9..9d90a58560 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Grid search optimizer for mlos_bench. -""" +"""Grid search optimizer for mlos_bench.""" import logging from typing import Dict, Iterable, Optional, Sequence, Set, Tuple @@ -24,9 +22,7 @@ class GridSearchOptimizer(TrackBestOptimizer): - """ - Grid search optimizer. - """ + """Grid search optimizer.""" def __init__(self, tunables: TunableGroups, @@ -120,9 +116,7 @@ def bulk_register(self, return True def suggest(self) -> TunableGroups: - """ - Generate the next grid search suggestion. - """ + """Generate the next grid search suggestion.""" tunables = super().suggest() if self._start_with_defaults: _LOG.info("Use default values for the first trial") diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index d7d50f1ca5..e9a522a683 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A wrapper for mlos_core optimizers for mlos_bench. -""" +"""A wrapper for mlos_core optimizers for mlos_bench.""" import logging import os @@ -36,9 +34,7 @@ class MlosCoreOptimizer(Optimizer): - """ - A wrapper class for the mlos_core optimizers. - """ + """A wrapper class for the mlos_core optimizers.""" def __init__(self, tunables: TunableGroups, @@ -127,17 +123,15 @@ def bulk_register(self, return True def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame: - """ - In-place adjust the signs of the scores for MINIMIZATION problem. - """ + """In-place adjust the signs of the scores for MINIMIZATION problem.""" for (opt_target, opt_dir) in self._opt_targets.items(): df_scores[opt_target] *= opt_dir return df_scores def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: """ - Select from past trials only the columns required in this experiment and - impute default values for the tunables that are missing in the dataframe. + Select from past trials only the columns required in this experiment and impute + default values for the tunables that are missing in the dataframe. Parameters ---------- diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index ada4411b58..2d70512b1f 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Mock optimizer for mlos_bench. -""" +"""Mock optimizer for mlos_bench.""" import logging import random @@ -20,9 +18,7 @@ class MockOptimizer(TrackBestOptimizer): - """ - Mock optimizer to test the Environment API. - """ + """Mock optimizer to test the Environment API.""" def __init__(self, tunables: TunableGroups, @@ -54,9 +50,7 @@ def bulk_register(self, return True def suggest(self) -> TunableGroups: - """ - Generate the next (random) suggestion. 
- """ + """Generate the next (random) suggestion.""" tunables = super().suggest() if self._start_with_defaults: _LOG.info("Use default tunable values") diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index 9ad1070c46..d0c0e531ef 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -No-op optimizer for mlos_bench that proposes a single configuration. -""" +"""No-op optimizer for mlos_bench that proposes a single configuration.""" import logging from typing import Optional @@ -19,6 +17,7 @@ class OneShotOptimizer(MockOptimizer): """ Mock optimizer that proposes a single configuration and returns. + Explicit configs (partial or full) are possible using configuration files. """ diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py index 32a23142e3..e90f81a6ea 100644 --- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Mock optimizer for mlos_bench. -""" +"""Mock optimizer for mlos_bench.""" import logging from abc import ABCMeta @@ -20,9 +18,7 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta): - """ - Base Optimizer class that keeps track of the best score and configuration. - """ + """Base Optimizer class that keeps track of the best score and configuration.""" def __init__(self, tunables: TunableGroups, @@ -42,9 +38,7 @@ def register(self, tunables: TunableGroups, status: Status, return registered_score def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: - """ - Compare the optimization scores to the best ones so far lexicographically. - """ + """Compare the optimization scores to the best ones so far lexicographically.""" if self._best_score is None: return True assert registered_score is not None diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index a7912688a1..44c15eb709 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py @@ -3,8 +3,8 @@ # Licensed under the MIT License. # """ -Simple platform agnostic abstraction for the OS environment variables. -Meant as a replacement for os.environ vs nt.environ. +Simple platform agnostic abstraction for the OS environment variables. Meant as a +replacement for os.environ vs nt.environ. Example ------- diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py index c54e3c0efc..a269560b73 100644 --- a/mlos_bench/mlos_bench/schedulers/__init__.py +++ b/mlos_bench/mlos_bench/schedulers/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Interfaces and implementations of the optimization loop scheduling policies. 
-""" +"""Interfaces and implementations of the optimization loop scheduling policies.""" from mlos_bench.schedulers.base_scheduler import Scheduler from mlos_bench.schedulers.sync_scheduler import SyncScheduler diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index 0b6733e423..b2a7328ebb 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base class for the optimization loop scheduling policies. -""" +"""Base class for the optimization loop scheduling policies.""" import json import logging @@ -27,9 +25,7 @@ class Scheduler(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes - """ - Base class for the optimization loop scheduling policies. - """ + """Base class for the optimization loop scheduling policies.""" def __init__(self, *, config: Dict[str, Any], @@ -39,10 +35,10 @@ def __init__(self, *, storage: Storage, root_env_config: str): """ - Create a new instance of the scheduler. The constructor of this - and the derived classes is called by the persistence service - after reading the class JSON configuration. Other objects like - the Environment and Optimizer are provided by the Launcher. + Create a new instance of the scheduler. The constructor of this and the derived + classes is called by the persistence service after reading the class JSON + configuration. Other objects like the Environment and Optimizer are provided by + the Launcher. Parameters ---------- @@ -96,9 +92,7 @@ def __repr__(self) -> str: return self.__class__.__name__ def __enter__(self) -> 'Scheduler': - """ - Enter the scheduler's context. - """ + """Enter the scheduler's context.""" _LOG.debug("Scheduler START :: %s", self) assert self.experiment is None self.environment.__enter__() @@ -121,9 +115,7 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], ex_val: Optional[BaseException], ex_tb: Optional[TracebackType]) -> Literal[False]: - """ - Exit the context of the scheduler. - """ + """Exit the context of the scheduler.""" if ex_val is None: _LOG.debug("Scheduler END :: %s", self) else: @@ -138,9 +130,7 @@ def __exit__(self, @abstractmethod def start(self) -> None: - """ - Start the optimization loop. - """ + """Start the optimization loop.""" assert self.experiment is not None _LOG.info("START: Experiment: %s Env: %s Optimizer: %s", self.experiment, self.environment, self.optimizer) @@ -154,6 +144,7 @@ def start(self) -> None: def teardown(self) -> None: """ Tear down the environment. + Call it after the completion of the `.start()` in the scheduler context. """ assert self.experiment is not None @@ -161,17 +152,13 @@ def teardown(self) -> None: self.environment.teardown() def get_best_observation(self) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: - """ - Get the best observation from the optimizer. - """ + """Get the best observation from the optimizer.""" (best_score, best_config) = self.optimizer.get_best_observation() _LOG.info("Env: %s best score: %s", self.environment, best_score) return (best_score, best_config) def load_config(self, config_id: int) -> TunableGroups: - """ - Load the existing tunable configuration from the storage. 
- """ + """Load the existing tunable configuration from the storage.""" assert self.experiment is not None tunable_values = self.experiment.load_tunable_config(config_id) tunables = self.environment.tunable_params.assign(tunable_values) @@ -182,9 +169,11 @@ def load_config(self, config_id: int) -> TunableGroups: def _schedule_new_optimizer_suggestions(self) -> bool: """ - Optimizer part of the loop. Load the results of the executed trials - into the optimizer, suggest new configurations, and add them to the queue. - Return True if optimization is not over, False otherwise. + Optimizer part of the loop. + + Load the results of the executed trials into the optimizer, suggest new + configurations, and add them to the queue. Return True if optimization is not + over, False otherwise. """ assert self.experiment is not None (trial_ids, configs, scores, status) = self.experiment.load(self._last_trial_id) @@ -200,9 +189,7 @@ def _schedule_new_optimizer_suggestions(self) -> bool: return not_done def schedule_trial(self, tunables: TunableGroups) -> None: - """ - Add a configuration to the queue of trials. - """ + """Add a configuration to the queue of trials.""" for repeat_i in range(1, self._trial_config_repeat_count + 1): self._add_trial_to_queue(tunables, config={ # Add some additional metadata to track for the trial such as the @@ -227,6 +214,7 @@ def _add_trial_to_queue(self, tunables: TunableGroups, config: Optional[Dict[str, Any]] = None) -> None: """ Add a configuration to the queue of trials. + A wrapper for the `Experiment.new_trial` method. """ assert self.experiment is not None @@ -235,7 +223,9 @@ def _add_trial_to_queue(self, tunables: TunableGroups, def _run_schedule(self, running: bool = False) -> None: """ - Scheduler part of the loop. Check for pending trials in the queue and run them. + Scheduler part of the loop. + + Check for pending trials in the queue and run them. """ assert self.experiment is not None for trial in self.experiment.pending_trials(datetime.now(UTC), running=running): @@ -244,6 +234,7 @@ def _run_schedule(self, running: bool = False) -> None: def not_done(self) -> bool: """ Check the stopping conditions. + By default, stop when the optimizer converges or max limit of trials reached. """ return self.optimizer.not_converged() and ( @@ -253,7 +244,9 @@ def not_done(self) -> bool: @abstractmethod def run_trial(self, trial: Storage.Trial) -> None: """ - Set up and run a single trial. Save the results in the storage. + Set up and run a single trial. + + Save the results in the storage. """ assert self.experiment is not None self._trial_count += 1 diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index a73a493533..0d3cfa0969 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A simple single-threaded synchronous optimization loop implementation. -""" +"""A simple single-threaded synchronous optimization loop implementation.""" import logging from datetime import datetime @@ -19,14 +17,10 @@ class SyncScheduler(Scheduler): - """ - A simple single-threaded synchronous optimization loop implementation. - """ + """A simple single-threaded synchronous optimization loop implementation.""" def start(self) -> None: - """ - Start the optimization loop. 
- """ + """Start the optimization loop.""" super().start() is_warm_up = self.optimizer.supports_preload @@ -42,7 +36,9 @@ def start(self) -> None: def run_trial(self, trial: Storage.Trial) -> None: """ - Set up and run a single trial. Save the results in the storage. + Set up and run a single trial. + + Save the results in the storage. """ super().run_trial(trial) diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py index bcc7d02d6f..b9b0b51693 100644 --- a/mlos_bench/mlos_bench/services/__init__.py +++ b/mlos_bench/mlos_bench/services/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Services for implementing Environments for mlos_bench. -""" +"""Services for implementing Environments for mlos_bench.""" from mlos_bench.services.base_fileshare import FileShareService from mlos_bench.services.base_service import Service diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index f00a7a1a00..07da98d11f 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base class for remote file shares. -""" +"""Base class for remote file shares.""" import logging from abc import ABCMeta, abstractmethod @@ -17,9 +15,7 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta): - """ - An abstract base of all file shares. - """ + """An abstract base of all file shares.""" def __init__(self, config: Optional[Dict[str, Any]] = None, global_config: Optional[Dict[str, Any]] = None, diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index e7c9365bf7..5b8a93fee6 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base class for the service mix-ins. -""" +"""Base class for the service mix-ins.""" import json import logging @@ -21,9 +19,7 @@ class Service: - """ - An abstract base of all Environment Services and used to build up mix-ins. - """ + """An abstract base of all Environment Services and used to build up mix-ins.""" @classmethod def new(cls, @@ -105,8 +101,9 @@ def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None], local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]: """ Merge methods from the external caller with the local ones. - This function is usually called by the derived class constructor - just before invoking the constructor of the base class. + + This function is usually called by the derived class constructor just before + invoking the constructor of the base class. """ if isinstance(local_methods, dict): local_methods = local_methods.copy() @@ -170,8 +167,8 @@ def _enter_context(self) -> "Service": """ Enters the context for this particular Service instance. - Called by the base __enter__ method of the Service class so it can be - used with mix-ins and overridden by subclasses. + Called by the base __enter__ method of the Service class so it can be used with + mix-ins and overridden by subclasses. 
""" assert not self._in_context self._in_context = True @@ -183,8 +180,8 @@ def _exit_context(self, ex_type: Optional[Type[BaseException]], """ Exits the context for this particular Service instance. - Called by the base __enter__ method of the Service class so it can be - used with mix-ins and overridden by subclasses. + Called by the base __enter__ method of the Service class so it can be used with + mix-ins and overridden by subclasses. """ # pylint: disable=unused-argument assert self._in_context @@ -192,10 +189,9 @@ def _exit_context(self, ex_type: Optional[Type[BaseException]], return False def _validate_json_config(self, config: dict) -> None: - """ - Reconstructs a basic json config that this class might have been - instantiated from in order to validate configs provided outside the - file loading mechanism. + """Reconstructs a basic json config that this class might have been instantiated + from in order to validate configs provided outside the file loading + mechanism. """ if self.__class__ == Service: # Skip over the case where instantiate a bare base Service class in order to build up a mix-in. @@ -212,9 +208,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}@{hex(id(self))}" def pprint(self) -> str: - """ - Produce a human-readable string listing all public methods of the service. - """ + """Produce a human-readable string listing all public methods of the service.""" return f"{self} ::\n" + "\n".join( f' "{key}": {getattr(val, "__self__", "stand-alone")}' for (key, val) in self._service_methods.items() diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index cac3216d61..9532b8388b 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -2,10 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Helper functions to load, instantiate, and serialize Python objects -that encapsulate benchmark environments, tunable parameters, and -service functions. +"""Helper functions to load, instantiate, and serialize Python objects that encapsulate +benchmark environments, tunable parameters, and service functions. """ import json # For logging only @@ -55,8 +53,8 @@ class ConfigPersistenceService(Service, SupportsConfigLoading): - """ - Collection of methods to deserialize the Environment, Service, and TunableGroups objects. + """Collection of methods to deserialize the Environment, Service, and TunableGroups + objects. """ BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/") @@ -123,8 +121,8 @@ def config_paths(self) -> List[str]: def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: """ - Prepend the suitable `_config_path` to `path` if the latter is not absolute. - If `_config_path` is `None` or `path` is absolute, return `path` as is. + Prepend the suitable `_config_path` to `path` if the latter is not absolute. If + `_config_path` is `None` or `path` is absolute, return `path` as is. Parameters ---------- @@ -156,9 +154,8 @@ def load_config(self, schema_type: Optional[ConfigSchema], ) -> Dict[str, Any]: """ - Load JSON config file. Search for a file relative to `_config_path` - if the input path is not absolute. - This method is exported to be used as a service. + Load JSON config file. Search for a file relative to `_config_path` if the input + path is not absolute. This method is exported to be used as a service. 
Parameters ---------- @@ -200,9 +197,8 @@ def prepare_class_load(self, config: Dict[str, Any], global_config: Optional[Dict[str, Any]] = None, parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]: """ - Extract the class instantiation parameters from the configuration. - Mix-in the global parameters and resolve the local file system paths, - where it is required. + Extract the class instantiation parameters from the configuration. Mix-in the + global parameters and resolve the local file system paths, where it is required. Parameters ---------- @@ -252,8 +248,7 @@ def build_optimizer(self, *, config: Dict[str, Any], global_config: Optional[Dict[str, Any]] = None) -> Optimizer: """ - Instantiation of mlos_bench Optimizer - that depend on Service and TunableGroups. + Instantiation of mlos_bench Optimizer that depend on Service and TunableGroups. A class *MUST* have a constructor that takes four named arguments: (tunables, config, global_config, service) @@ -589,8 +584,8 @@ def load_services(self, json_file_names: Iterable[str], global_config: Optional[Dict[str, Any]] = None, parent: Optional[Service] = None) -> Service: """ - Read the configuration files and bundle all service methods - from those configs into a single Service object. + Read the configuration files and bundle all service methods from those configs + into a single Service object. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py index abb87c8b52..bf1361024a 100644 --- a/mlos_bench/mlos_bench/services/local/__init__.py +++ b/mlos_bench/mlos_bench/services/local/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Local scheduler side Services for mlos_bench. -""" +"""Local scheduler side Services for mlos_bench.""" from mlos_bench.services.local.local_exec import LocalExecService diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 47534be7b1..189a54b210 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Helper functions to run scripts and commands locally on the scheduler side. -""" +"""Helper functions to run scripts and commands locally on the scheduler side.""" import errno import logging @@ -39,9 +37,9 @@ def split_cmdline(cmdline: str) -> Iterable[List[str]]: """ - A single command line may contain multiple commands separated by - special characters (e.g., &&, ||, etc.) so further split the - commandline into an array of subcommand arrays. + A single command line may contain multiple commands separated by special characters + (e.g., &&, ||, etc.) so further split the commandline into an array of subcommand + arrays. Parameters ---------- @@ -74,9 +72,11 @@ def split_cmdline(cmdline: str) -> Iterable[List[str]]: class LocalExecService(TempDirContextService, SupportsLocalExec): """ - Collection of methods to run scripts and commands in an external process - on the node acting as the scheduler. Can be useful for data processing - due to reduced dependency management complications vs the target environment. + Collection of methods to run scripts and commands in an external process on the node + acting as the scheduler. 
+ + Can be useful for data processing due to reduced dependency management complications + vs the target environment. """ def __init__(self, @@ -149,8 +149,8 @@ def local_exec(self, script_lines: Iterable[str], def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: """ - Resolves local script path (first token) in the (sub)command line - tokens to its full path. + Resolves local script path (first token) in the (sub)command line tokens to its + full path. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py index a0cf3e0e57..4221754cb0 100644 --- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py +++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Helper functions to work with temp files locally on the scheduler side. -""" +"""Helper functions to work with temp files locally on the scheduler side.""" import abc import logging @@ -21,11 +19,11 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta): """ - A *base* service class that provides a method to create a temporary - directory context for local scripts. + A *base* service class that provides a method to create a temporary directory + context for local scripts. - It is inherited by LocalExecService and MockLocalExecService. - This class is not supposed to be used as a standalone service. + It is inherited by LocalExecService and MockLocalExecService. This class is not + supposed to be used as a standalone service. """ def __init__(self, @@ -34,8 +32,8 @@ def __init__(self, parent: Optional[Service] = None, methods: Union[Dict[str, Callable], List[Callable], None] = None): """ - Create a new instance of a service that provides temporary directory context - for local exec service. + Create a new instance of a service that provides temporary directory context for + local exec service. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py index 61a6c74942..0a148250c3 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Azure-specific benchmark environments for mlos_bench. -""" +"""Azure-specific benchmark environments for mlos_bench.""" from mlos_bench.services.remote.azure.azure_auth import AzureAuthService from mlos_bench.services.remote.azure.azure_fileshare import AzureFileShareService diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 4121446caf..350ecd6e5f 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for managing VMs on Azure. -""" +"""A collection Service functions for managing VMs on Azure.""" import logging from base64 import b64decode @@ -23,9 +21,7 @@ class AzureAuthService(Service, SupportsAuth): - """ - Helper methods to get access to Azure services. 
- """ + """Helper methods to get access to Azure services.""" _REQ_INTERVAL = 300 # = 5 min @@ -107,9 +103,7 @@ def _init_sp(self) -> None: self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes) def get_access_token(self) -> str: - """ - Get the access token from Azure CLI, if expired. - """ + """Get the access token from Azure CLI, if expired.""" # Ensure we are logged as the Service Principal, if provided if "spClientId" in self.config: self._init_sp() @@ -125,7 +119,5 @@ def get_access_token(self) -> str: return self._access_token def get_auth_headers(self) -> dict: - """ - Get the authorization part of HTTP headers for REST API calls. - """ + """Get the authorization part of HTTP headers for REST API calls.""" return {"Authorization": "Bearer " + self.get_access_token()} diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 9f2b504aff..dc2c049c1e 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base class for certain Azure Services classes that do deployments. -""" +"""Base class for certain Azure Services classes that do deployments.""" import abc import json @@ -25,9 +23,7 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): - """ - Helper methods to manage and deploy Azure resources via REST APIs. - """ + """Helper methods to manage and deploy Azure resources via REST APIs.""" _POLL_INTERVAL = 4 # seconds _POLL_TIMEOUT = 300 # seconds @@ -98,9 +94,7 @@ def __init__(self, @property def deploy_params(self) -> dict: - """ - Get the deployment parameters. - """ + """Get the deployment parameters.""" return self._deploy_params @abc.abstractmethod @@ -121,8 +115,8 @@ def _set_default_params(self, params: dict) -> dict: raise NotImplementedError("Should be overridden by subclass.") def _get_session(self, params: dict) -> requests.Session: - """ - Get a session object that includes automatic retries and headers for REST API calls. + """Get a session object that includes automatic retries and headers for REST API + calls. """ total_retries = params.get("requestTotalRetries", self._total_retries) backoff_factor = params.get("requestBackoffFactor", self._backoff_factor) @@ -134,9 +128,7 @@ def _get_session(self, params: dict) -> requests.Session: return session def _get_headers(self) -> dict: - """ - Get the headers for the REST API calls. - """ + """Get the headers for the REST API calls.""" assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() @@ -251,8 +243,8 @@ def _check_operation_status(self, params: dict) -> Tuple[Status, dict]: def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: """ - Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or FAILED. - Return TIMED_OUT when timing out. + Waits for a pending operation on an Azure resource to resolve to SUCCEEDED or + FAILED. Return TIMED_OUT when timing out. 
Parameters ---------- @@ -276,8 +268,8 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], loop_status: Status, params: dict) -> Tuple[Status, dict]: """ - Invoke `func` periodically while the status is equal to `loop_status`. - Return TIMED_OUT when timing out. + Invoke `func` periodically while the status is equal to `loop_status`. Return + TIMED_OUT when timing out. Parameters ---------- @@ -327,8 +319,7 @@ def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements """ - Check if Azure deployment exists. - Return SUCCEEDED if true, PENDING otherwise. + Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 6ccd4ba09d..653963922d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection FileShare functions for interacting with Azure File Shares. -""" +"""A collection FileShare functions for interacting with Azure File Shares.""" import logging import os @@ -21,9 +19,7 @@ class AzureFileShareService(FileShareService): - """ - Helper methods for interacting with Azure File Share - """ + """Helper methods for interacting with Azure File Share.""" _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}" @@ -100,10 +96,9 @@ def upload(self, params: dict, local_path: str, remote_path: str, recursive: boo def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[str]) -> None: """ - Upload contents from a local path to an Azure file share. - This method is called from `.upload()` above. We need it to avoid exposing - the `seen` parameter and to make `.upload()` match the base class' virtual - method. + Upload contents from a local path to an Azure file share. This method is called + from `.upload()` above. We need it to avoid exposing the `seen` parameter and to + make `.upload()` match the base class' virtual method. Parameters ---------- @@ -142,8 +137,8 @@ def _upload(self, local_path: str, remote_path: str, recursive: bool, seen: Set[ def _remote_makedirs(self, remote_path: str) -> None: """ - Create remote directories for the entire path. - Succeeds even some or all directories along the path already exist. + Create remote directories for the entire path. Succeeds even some or all + directories along the path already exist. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index d65ee02cfd..ff6eb160fd 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for managing virtual networks on Azure. 
-""" +"""A collection Service functions for managing virtual networks on Azure.""" import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -23,9 +21,7 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): - """ - Helper methods to manage Virtual Networks on Azure. - """ + """Helper methods to manage Virtual Networks on Azure.""" # Azure Compute REST API calls as described in # https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index a92d279a6d..b78a069c62 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for configuring SaaS instances on Azure. -""" +"""A collection Service functions for configuring SaaS instances on Azure.""" import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -20,9 +18,7 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig): - """ - Helper methods to configure Azure Flex services. - """ + """Helper methods to configure Azure Flex services.""" _REQUEST_TIMEOUT = 5 # seconds @@ -174,9 +170,7 @@ def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]: )}) def _get_headers(self) -> dict: - """ - Get the headers for the REST API calls. - """ + """Get the headers for the REST API calls.""" assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() @@ -218,8 +212,8 @@ def _config_one(self, config: Dict[str, Any], def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: """ - Update the parameters of an Azure DB service one-by-one. - (If batch API is not available for it). + Update the parameters of an Azure DB service one-by-one. (If batch API is not + available for it). Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index ddce3cc935..384618415d 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for managing VMs on Azure. -""" +"""A collection Service functions for managing VMs on Azure.""" import json import logging @@ -27,9 +25,7 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec): - """ - Helper methods to manage VMs on Azure. - """ + """Helper methods to manage VMs on Azure.""" # pylint: disable=too-many-ancestors @@ -277,7 +273,8 @@ def deprovision_host(self, params: dict) -> Tuple[Status, dict]: def deallocate_host(self, params: dict) -> Tuple[Status, dict]: """ - Deallocates the VM on Azure by shutting it down then releasing the compute resources. + Deallocates the VM on Azure by shutting it down then releasing the compute + resources. Note: This can cause the VM to arrive on a new host node when its restarted, which may have different performance characteristics. 
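The Azure deployment-service docstrings reformatted above all describe the same polling pattern: re-check an operation's status every few seconds while it is still pending, and give up with TIMED_OUT once the timeout budget is spent. Below is a simplified standalone sketch of that pattern, assuming a minimal Status enum and using the poll interval and timeout defaults shown above as illustrative values; the real service's internals may differ:

import enum
import time
from typing import Callable, Tuple


class Status(enum.Enum):
    """Minimal stand-in for the benchmark/environment status enum."""

    PENDING = enum.auto()
    SUCCEEDED = enum.auto()
    FAILED = enum.auto()
    TIMED_OUT = enum.auto()


def wait_while(check: Callable[[dict], Tuple[Status, dict]],
               loop_status: Status,
               params: dict,
               poll_interval: float = 4.0,
               poll_timeout: float = 300.0) -> Tuple[Status, dict]:
    """Call `check` periodically while its status equals `loop_status`.

    Returns the first non-matching (status, output) pair, or TIMED_OUT when the
    timeout budget is exhausted.
    """
    deadline = time.monotonic() + poll_timeout
    while time.monotonic() < deadline:
        (status, output) = check(params)
        if status != loop_status:
            return (status, output)
        time.sleep(poll_interval)
    return (Status.TIMED_OUT, {})

A caller would supply its own status-check callable, e.g. one that issues the REST request and maps the response to a (Status, dict) pair, mirroring the return convention used throughout these services.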
diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index f623cdfcc8..94947f69b0 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection functions for interacting with SSH servers as file shares. -""" +"""A collection functions for interacting with SSH servers as file shares.""" import logging from enum import Enum @@ -20,9 +18,7 @@ class CopyMode(Enum): - """ - Copy mode enum. - """ + """Copy mode enum.""" DOWNLOAD = 1 UPLOAD = 2 @@ -36,7 +32,7 @@ async def _start_file_copy(self, params: dict, mode: CopyMode, recursive: bool = True) -> None: # pylint: disable=too-many-arguments """ - Starts a file copy operation + Starts a file copy operation. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index a650ff0707..26e886b83d 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for managing hosts via SSH. -""" +"""A collection Service functions for managing hosts via SSH.""" import logging from concurrent.futures import Future @@ -23,9 +21,7 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): - """ - Helper methods to manage machines via SSH. - """ + """Helper methods to manage machines via SSH.""" # pylint: disable=too-many-instance-attributes @@ -166,7 +162,8 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: return (Status.FAILED, {"result": result}) def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, dict]: - """_summary_ + """ + _summary_ Parameters ---------- @@ -258,8 +255,8 @@ def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]: def wait_os_operation(self, params: dict) -> Tuple[Status, dict]: """ - Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. - Return TIMED_OUT when timing out. + Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return + TIMED_OUT when timing out. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 8bc90eb3da..272f908c78 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection functions for interacting with SSH servers as file shares. -""" +"""A collection functions for interacting with SSH servers as file shares.""" import logging import os @@ -45,9 +43,9 @@ class SshClient(asyncssh.SSHClient): """ Wrapper around SSHClient to help provide connection caching and reconnect logic. - Used by the SshService to try and maintain a single connection to hosts, - handle reconnects if possible, and use that to run commands rather than - reconnect for each command. + Used by the SshService to try and maintain a single connection to hosts, handle + reconnects if possible, and use that to run commands rather than reconnect for each + command. 
""" _CONNECTION_PENDING = 'INIT' @@ -99,9 +97,7 @@ def connection_lost(self, exc: Optional[Exception]) -> None: return super().connection_lost(exc) async def connection(self) -> Optional[SSHClientConnection]: - """ - Waits for and returns the SSHClientConnection to be established or lost. - """ + """Waits for and returns the SSHClientConnection to be established or lost.""" _LOG.debug("%s: Waiting for connection to be available.", current_thread().name) await self._conn_event.wait() _LOG.debug("%s: Connection available for %s", current_thread().name, self._connection_id) @@ -111,6 +107,7 @@ async def connection(self) -> Optional[SSHClientConnection]: class SshClientCache: """ Manages a cache of SshClient connections. + Note: Only one per event loop thread supported. See additional details in SshService comments. """ @@ -129,6 +126,7 @@ def __len__(self) -> int: def enter(self) -> None: """ Manages the cache lifecycle with reference counting. + To be used in the __enter__ method of a caller's context manager. """ self._refcnt += 1 @@ -136,6 +134,7 @@ def enter(self) -> None: def exit(self) -> None: """ Manages the cache lifecycle with reference counting. + To be used in the __exit__ method of a caller's context manager. """ self._refcnt -= 1 @@ -182,18 +181,14 @@ async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientCo return self._cache[connection_id] def cleanup(self) -> None: - """ - Closes all cached connections. - """ + """Closes all cached connections.""" for (connection, _) in self._cache.values(): connection.close() self._cache = {} class SshService(Service, metaclass=ABCMeta): - """ - Base class for SSH services. - """ + """Base class for SSH services.""" # AsyncSSH requires an asyncio event loop to be running to work. # However, running that event loop blocks the main thread. @@ -291,6 +286,7 @@ def _exit_context(self, ex_type: Optional[Type[BaseException]], def clear_client_cache(cls) -> None: """ Clears the cache of client connections. + Note: This may cause in flight operations to fail. """ cls._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.cleanup() diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index 725d0c3306..e691d64514 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Service types for implementing declaring Service behavior for Environments to use in mlos_bench. +"""Service types for implementing declaring Service behavior for Environments to use in +mlos_bench. """ from mlos_bench.services.types.authenticator_type import SupportsAuth diff --git a/mlos_bench/mlos_bench/services/types/authenticator_type.py b/mlos_bench/mlos_bench/services/types/authenticator_type.py index fcec792d7d..6f99dd6bce 100644 --- a/mlos_bench/mlos_bench/services/types/authenticator_type.py +++ b/mlos_bench/mlos_bench/services/types/authenticator_type.py @@ -2,18 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for authentication for the cloud services. -""" +"""Protocol interface for authentication for the cloud services.""" from typing import Protocol, runtime_checkable @runtime_checkable class SupportsAuth(Protocol): - """ - Protocol interface for authentication for the cloud services. 
- """ + """Protocol interface for authentication for the cloud services.""" def get_access_token(self) -> str: """ diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index 05853da0a9..c0b2d7335b 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for helper functions to lookup and load configs. -""" +"""Protocol interface for helper functions to lookup and load configs.""" from typing import ( TYPE_CHECKING, @@ -30,15 +28,13 @@ @runtime_checkable class SupportsConfigLoading(Protocol): - """ - Protocol interface for helper functions to lookup and load configs. - """ + """Protocol interface for helper functions to lookup and load configs.""" def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: """ - Prepend the suitable `_config_path` to `path` if the latter is not absolute. - If `_config_path` is `None` or `path` is absolute, return `path` as is. + Prepend the suitable `_config_path` to `path` if the latter is not absolute. If + `_config_path` is `None` or `path` is absolute, return `path` as is. Parameters ---------- @@ -55,9 +51,8 @@ def resolve_path(self, file_path: str, def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]: """ - Load JSON config file. Search for a file relative to `_config_path` - if the input path is not absolute. - This method is exported to be used as a service. + Load JSON config file. Search for a file relative to `_config_path` if the input + path is not absolute. This method is exported to be used as a service. Parameters ---------- @@ -141,8 +136,8 @@ def load_services(self, json_file_names: Iterable[str], global_config: Optional[Dict[str, Any]] = None, parent: Optional["Service"] = None) -> "Service": """ - Read the configuration files and bundle all service methods - from those configs into a single Service object. + Read the configuration files and bundle all service methods from those configs + into a single Service object. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py index 87ec9e49da..607f5cb674 100644 --- a/mlos_bench/mlos_bench/services/types/fileshare_type.py +++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py @@ -2,18 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for file share operations. -""" +"""Protocol interface for file share operations.""" from typing import Protocol, runtime_checkable @runtime_checkable class SupportsFileShareOps(Protocol): - """ - Protocol interface for file share operations. - """ + """Protocol interface for file share operations.""" def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: """ diff --git a/mlos_bench/mlos_bench/services/types/host_ops_type.py b/mlos_bench/mlos_bench/services/types/host_ops_type.py index 5418f8b1d3..166406714d 100644 --- a/mlos_bench/mlos_bench/services/types/host_ops_type.py +++ b/mlos_bench/mlos_bench/services/types/host_ops_type.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for Host/VM boot operations. 
-""" +"""Protocol interface for Host/VM boot operations.""" from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable @@ -14,9 +12,7 @@ @runtime_checkable class SupportsHostOps(Protocol): - """ - Protocol interface for Host/VM boot operations. - """ + """Protocol interface for Host/VM boot operations.""" def start_host(self, params: dict) -> Tuple["Status", dict]: """ diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index 77b481e48e..1be95aab22 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for Host/VM provisioning operations. -""" +"""Protocol interface for Host/VM provisioning operations.""" from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable @@ -14,9 +12,7 @@ @runtime_checkable class SupportsHostProvisioning(Protocol): - """ - Protocol interface for Host/VM provisioning operations. - """ + """Protocol interface for Host/VM provisioning operations.""" def provision_host(self, params: dict) -> Tuple["Status", dict]: """ @@ -74,7 +70,8 @@ def deprovision_host(self, params: dict) -> Tuple["Status", dict]: def deallocate_host(self, params: dict) -> Tuple["Status", dict]: """ - Deallocates the Host/VM by shutting it down then releasing the compute resources. + Deallocates the Host/VM by shutting it down then releasing the compute + resources. Note: This can cause the VM to arrive on a new host node when its restarted, which may have different performance characteristics. diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index c4c5f01ddc..1c8f5f627e 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -2,9 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for Service types that provide helper functions to run -scripts and commands locally on the scheduler side. +"""Protocol interface for Service types that provide helper functions to run scripts and +commands locally on the scheduler side. """ import contextlib @@ -25,11 +24,11 @@ @runtime_checkable class SupportsLocalExec(Protocol): """ - Protocol interface for a collection of methods to run scripts and commands - in an external process on the node acting as the scheduler. Can be useful - for data processing due to reduced dependency management complications vs - the target environment. - Used in LocalEnv and provided by LocalExecService. + Protocol interface for a collection of methods to run scripts and commands in an + external process on the node acting as the scheduler. + + Can be useful for data processing due to reduced dependency management complications + vs the target environment. Used in LocalEnv and provided by LocalExecService. """ def local_exec(self, script_lines: Iterable[str], diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index fb753aa21c..5c1812f1f0 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# -""" -Protocol interface for Network provisioning operations. -""" +"""Protocol interface for Network provisioning operations.""" from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable @@ -14,9 +12,7 @@ @runtime_checkable class SupportsNetworkProvisioning(Protocol): - """ - Protocol interface for Network provisioning operations. - """ + """Protocol interface for Network provisioning operations.""" def provision_network(self, params: dict) -> Tuple["Status", dict]: """ diff --git a/mlos_bench/mlos_bench/services/types/os_ops_type.py b/mlos_bench/mlos_bench/services/types/os_ops_type.py index 6d5cea34e5..8b727f87a6 100644 --- a/mlos_bench/mlos_bench/services/types/os_ops_type.py +++ b/mlos_bench/mlos_bench/services/types/os_ops_type.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for Host/OS operations. -""" +"""Protocol interface for Host/OS operations.""" from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable @@ -14,9 +12,7 @@ @runtime_checkable class SupportsOSOps(Protocol): - """ - Protocol interface for Host/OS operations. - """ + """Protocol interface for Host/OS operations.""" def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]: """ @@ -56,8 +52,8 @@ def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]: def wait_os_operation(self, params: dict) -> Tuple["Status", dict]: """ - Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. - Return TIMED_OUT when timing out. + Waits for a pending operation on an OS to resolve to SUCCEEDED or FAILED. Return + TIMED_OUT when timing out. Parameters ---------- diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index c653e10c2b..c25bc7b0ba 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for configuring cloud services. -""" +"""Protocol interface for configuring cloud services.""" from typing import TYPE_CHECKING, Any, Dict, Protocol, Tuple, runtime_checkable @@ -14,9 +12,7 @@ @runtime_checkable class SupportsRemoteConfig(Protocol): - """ - Protocol interface for configuring cloud services. - """ + """Protocol interface for configuring cloud services.""" def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple["Status", dict]: diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index 096cb3c675..cba9e31b22 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -2,9 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for Service types that provide helper functions to run -scripts on a remote host OS. +"""Protocol interface for Service types that provide helper functions to run scripts on +a remote host OS. """ from typing import TYPE_CHECKING, Iterable, Protocol, Tuple, runtime_checkable @@ -15,9 +14,8 @@ @runtime_checkable class SupportsRemoteExec(Protocol): - """ - Protocol interface for Service types that provide helper functions to run - scripts on a remote host OS. 
+ """Protocol interface for Service types that provide helper functions to run scripts + on a remote host OS. """ def remote_exec(self, script: Iterable[str], config: dict, diff --git a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py index 19747b3f12..69d24f3fd3 100644 --- a/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/vm_provisioner_type.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Protocol interface for VM provisioning operations. -""" +"""Protocol interface for VM provisioning operations.""" from typing import TYPE_CHECKING, Protocol, Tuple, runtime_checkable @@ -14,9 +12,7 @@ @runtime_checkable class SupportsVMOps(Protocol): - """ - Protocol interface for VM provisioning operations. - """ + """Protocol interface for VM provisioning operations.""" def vm_provision(self, params: dict) -> Tuple["Status", dict]: """ @@ -122,8 +118,8 @@ def vm_deprovision(self, params: dict) -> Tuple["Status", dict]: def wait_vm_operation(self, params: dict) -> Tuple["Status", dict]: """ - Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED. - Return TIMED_OUT when timing out. + Waits for a pending operation on a VM to resolve to SUCCEEDED or FAILED. Return + TIMED_OUT when timing out. Parameters ---------- diff --git a/mlos_bench/mlos_bench/storage/__init__.py b/mlos_bench/mlos_bench/storage/__init__.py index 9ae5c80f36..a5bfeb7145 100644 --- a/mlos_bench/mlos_bench/storage/__init__.py +++ b/mlos_bench/mlos_bench/storage/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Interfaces to the storage backends for OS Autotune. -""" +"""Interfaces to the storage backends for OS Autotune.""" from mlos_bench.storage.base_storage import Storage from mlos_bench.storage.storage_factory import from_config diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index ce07e44e2b..a6cb7d496a 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base interface for accessing the stored benchmark experiment data. -""" +"""Base interface for accessing the stored benchmark experiment data.""" from abc import ABCMeta, abstractmethod from distutils.util import strtobool # pylint: disable=deprecated-module @@ -46,16 +44,12 @@ def __init__(self, *, @property def experiment_id(self) -> str: - """ - ID of the experiment. - """ + """ID of the experiment.""" return self._experiment_id @property def description(self) -> str: - """ - Description of the experiment. - """ + """Description of the experiment.""" return self._description @property @@ -125,7 +119,8 @@ def tunable_config_trial_groups(self) -> Dict[int, "TunableConfigTrialGroupData" @property def default_tunable_config_id(self) -> Optional[int]: """ - Retrieves the (tunable) config id for the default tunable values for this experiment. + Retrieves the (tunable) config id for the default tunable values for this + experiment. Note: this is by *default* the first trial executed for this experiment. 
However, it is currently possible that the user changed the tunables config diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 2165fa706f..39b3bf851b 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base interface for saving and restoring the benchmark data. -""" +"""Base interface for saving and restoring the benchmark data.""" import logging from abc import ABCMeta, abstractmethod @@ -25,9 +23,8 @@ class Storage(metaclass=ABCMeta): - """ - An abstract interface between the benchmarking framework - and storage systems (e.g., SQLite or MLFLow). + """An abstract interface between the benchmarking framework and storage systems + (e.g., SQLite or MLFLow). """ def __init__(self, @@ -49,10 +46,9 @@ def __init__(self, self._global_config = global_config or {} def _validate_json_config(self, config: dict) -> None: - """ - Reconstructs a basic json config that this class might have been - instantiated from in order to validate configs provided outside the - file loading mechanism. + """Reconstructs a basic json config that this class might have been instantiated + from in order to validate configs provided outside the file loading + mechanism. """ json_config: dict = { "class": self.__class__.__module__ + "." + self.__class__.__name__, @@ -113,6 +109,7 @@ class Experiment(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ Base interface for storing the results of the experiment. + This class is instantiated in the `Storage.experiment()` method. """ @@ -169,7 +166,8 @@ def __repr__(self) -> str: def _setup(self) -> None: """ - Create a record of the new experiment or find an existing one in the storage. + Create a record of the new experiment or find an existing one in the + storage. This method is called by `Storage.Experiment.__enter__()`. """ @@ -188,36 +186,34 @@ def _teardown(self, is_ok: bool) -> None: @property def experiment_id(self) -> str: - """Get the Experiment's ID""" + """Get the Experiment's ID.""" return self._experiment_id @property def trial_id(self) -> int: - """Get the current Trial ID""" + """Get the current Trial ID.""" return self._trial_id @property def description(self) -> str: - """Get the Experiment's description""" + """Get the Experiment's description.""" return self._description @property def tunables(self) -> TunableGroups: - """Get the Experiment's tunables""" + """Get the Experiment's tunables.""" return self._tunables @property def opt_targets(self) -> Dict[str, Literal["min", "max"]]: - """ - Get the Experiment's optimization targets and directions - """ + """Get the Experiment's optimization targets and directions.""" return self._opt_targets @abstractmethod def merge(self, experiment_ids: List[str]) -> None: """ - Merge in the results of other (compatible) experiments trials. - Used to help warm up the optimizer for this experiment. + Merge in the results of other (compatible) experiments trials. Used to help + warm up the optimizer for this experiment. Parameters ---------- @@ -227,9 +223,7 @@ def merge(self, experiment_ids: List[str]) -> None: @abstractmethod def load_tunable_config(self, config_id: int) -> Dict[str, Any]: - """ - Load tunable values for a given config ID. 
- """ + """Load tunable values for a given config ID.""" @abstractmethod def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: @@ -271,8 +265,8 @@ def load(self, last_trial_id: int = -1, @abstractmethod def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']: """ - Return an iterator over the pending trials that are scheduled to run - on or before the specified timestamp. + Return an iterator over the pending trials that are scheduled to run on or + before the specified timestamp. Parameters ---------- @@ -314,6 +308,7 @@ class Trial(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """ Base interface for storing the results of a single run of the experiment. + This class is instantiated in the `Storage.Experiment.trial()` method. """ @@ -333,29 +328,23 @@ def __repr__(self) -> str: @property def trial_id(self) -> int: - """ - ID of the current trial. - """ + """ID of the current trial.""" return self._trial_id @property def tunable_config_id(self) -> int: - """ - ID of the current trial (tunable) configuration. - """ + """ID of the current trial (tunable) configuration.""" return self._tunable_config_id @property def opt_targets(self) -> Dict[str, Literal["min", "max"]]: - """ - Get the Trial's optimization targets and directions. - """ + """Get the Trial's optimization targets and directions.""" return self._opt_targets @property def tunables(self) -> TunableGroups: """ - Tunable parameters of the current trial + Tunable parameters of the current trial. (e.g., application Environment's "config") """ @@ -363,8 +352,8 @@ def tunables(self) -> TunableGroups: def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: """ - Produce a copy of the global configuration updated - with the parameters of the current trial. + Produce a copy of the global configuration updated with the parameters of + the current trial. Note: this is not the target Environment's "config" (i.e., tunable params), but rather the internal "config" which consists of a diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index b3b2bed86a..f9f7b93322 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base interface for accessing the stored benchmark trial data. -""" +"""Base interface for accessing the stored benchmark trial data.""" from abc import ABCMeta, abstractmethod from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, Optional @@ -27,8 +25,8 @@ class TrialData(metaclass=ABCMeta): """ Base interface for accessing the stored experiment benchmark trial data. - A trial is a single run of an experiment with a given configuration (e.g., set - of tunable parameters). + A trial is a single run of an experiment with a given configuration (e.g., set of + tunable parameters). """ def __init__(self, *, @@ -57,44 +55,32 @@ def __eq__(self, other: Any) -> bool: @property def experiment_id(self) -> str: - """ - ID of the experiment this trial belongs to. - """ + """ID of the experiment this trial belongs to.""" return self._experiment_id @property def trial_id(self) -> int: - """ - ID of the trial. - """ + """ID of the trial.""" return self._trial_id @property def ts_start(self) -> datetime: - """ - Start timestamp of the trial (UTC). 
- """ + """Start timestamp of the trial (UTC).""" return self._ts_start @property def ts_end(self) -> Optional[datetime]: - """ - End timestamp of the trial (UTC). - """ + """End timestamp of the trial (UTC).""" return self._ts_end @property def status(self) -> Status: - """ - Status of the trial. - """ + """Status of the trial.""" return self._status @property def tunable_config_id(self) -> int: - """ - ID of the (tunable) configuration of the trial. - """ + """ID of the (tunable) configuration of the trial.""" return self._tunable_config_id @property @@ -114,9 +100,7 @@ def tunable_config(self) -> TunableConfigData: @property @abstractmethod def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": - """ - Retrieve the trial's (tunable) config trial group data from the storage. - """ + """Retrieve the trial's (tunable) config trial group data from the storage.""" @property @abstractmethod diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py index 0dce110b1b..0d58c20dc8 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base interface for accessing the stored benchmark (tunable) config data. -""" +"""Base interface for accessing the stored benchmark (tunable) config data.""" from abc import ABCMeta, abstractmethod from typing import Any, Dict, Optional @@ -35,9 +33,7 @@ def __eq__(self, other: Any) -> bool: @property def tunable_config_id(self) -> int: - """ - Unique ID of the (tunable) configuration. - """ + """Unique ID of the (tunable) configuration.""" return self._tunable_config_id @property diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py index 18c50035a9..62c01c3266 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base interface for accessing the stored benchmark config trial group data. -""" +"""Base interface for accessing the stored benchmark config trial group data.""" from abc import ABCMeta, abstractmethod from typing import TYPE_CHECKING, Any, Dict, Optional @@ -19,8 +17,8 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta): """ - Base interface for accessing the stored experiment benchmark tunable config - trial group data. + Base interface for accessing the stored experiment benchmark tunable config trial + group data. A (tunable) config is used to define an instance of values for a set of tunable parameters for a given experiment and can be used by one or more trial instances @@ -38,23 +36,17 @@ def __init__(self, *, @property def experiment_id(self) -> str: - """ - ID of the experiment. - """ + """ID of the experiment.""" return self._experiment_id @property def tunable_config_id(self) -> int: - """ - ID of the config. - """ + """ID of the config.""" return self._tunable_config_id @abstractmethod def _get_tunable_config_trial_group_id(self) -> int: - """ - Retrieve the trial's config_trial_group_id from the storage. 
- """ + """Retrieve the trial's config_trial_group_id from the storage.""" raise NotImplementedError("subclass must implement") @property @@ -83,7 +75,8 @@ def __eq__(self, other: Any) -> bool: @abstractmethod def tunable_config(self) -> TunableConfigData: """ - Retrieve the (tunable) config data for this (tunable) config trial group from the storage. + Retrieve the (tunable) config data for this (tunable) config trial group from + the storage. Returns ------- @@ -94,7 +87,8 @@ def tunable_config(self) -> TunableConfigData: @abstractmethod def trials(self) -> Dict[int, "TrialData"]: """ - Retrieve the trials' data for this (tunable) config trial group from the storage. + Retrieve the trials' data for this (tunable) config trial group from the + storage. Returns ------- @@ -106,7 +100,8 @@ def trials(self) -> Dict[int, "TrialData"]: @abstractmethod def results_df(self) -> pandas.DataFrame: """ - Retrieve all results for this (tunable) config trial group as a single DataFrame. + Retrieve all results for this (tunable) config trial group as a single + DataFrame. Returns ------- diff --git a/mlos_bench/mlos_bench/storage/sql/__init__.py b/mlos_bench/mlos_bench/storage/sql/__init__.py index 735e21bcaf..86fd6de291 100644 --- a/mlos_bench/mlos_bench/storage/sql/__init__.py +++ b/mlos_bench/mlos_bench/storage/sql/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Interfaces to the SQL-based storage backends for OS Autotune. -""" +"""Interfaces to the SQL-based storage backends for OS Autotune.""" from mlos_bench.storage.sql.storage import SqlStorage __all__ = [ diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index c7ee73a3bc..fed66b339d 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Common SQL methods for accessing the stored benchmark data. -""" +"""Common SQL methods for accessing the stored benchmark data.""" from typing import Dict, Optional import pandas @@ -23,8 +21,9 @@ def get_trials( experiment_id: str, tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]: """ - Gets TrialData for the given experiment_data and optionally additionally - restricted by tunable_config_id. + Gets TrialData for the given experiment_data and optionally additionally restricted + by tunable_config_id. + Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData. """ from mlos_bench.storage.sql.trial_data import ( @@ -65,8 +64,9 @@ def get_results_df( experiment_id: str, tunable_config_id: Optional[int] = None) -> pandas.DataFrame: """ - Gets TrialData for the given experiment_data and optionally additionally - restricted by tunable_config_id. + Gets TrialData for the given experiment_data and optionally additionally restricted + by tunable_config_id. + Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData. """ # pylint: disable=too-many-locals diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 58ee3dddb5..c96cd503be 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Saving and restoring the benchmark data using SQLAlchemy. 
-""" +"""Saving and restoring the benchmark data using SQLAlchemy.""" import hashlib import logging @@ -25,9 +23,7 @@ class Experiment(Storage.Experiment): - """ - Logic for retrieving and storing the results of a single experiment. - """ + """Logic for retrieving and storing the results of a single experiment.""" def __init__(self, *, engine: Engine, @@ -169,6 +165,7 @@ def load(self, last_trial_id: int = -1, def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> Dict[str, Any]: """ Helper method to retrieve key-value pairs from the database. + (E.g., configurations, results, and telemetry). """ cur_result: CursorResult[Tuple[str, Any]] = conn.execute( @@ -232,8 +229,9 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: """ - Get the config ID for the given tunables. If the config does not exist, - create a new record for it. + Get the config ID for the given tunables. + + If the config does not exist, create a new record for it. """ config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest() cur_config = conn.execute(self._schema.config.select().where( diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index eaa6e1041f..b92885d1fd 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -An interface to access the experiment benchmark data stored in SQL DB. -""" +"""An interface to access the experiment benchmark data stored in SQL DB.""" import logging from typing import Dict, Literal, Optional @@ -127,7 +125,8 @@ def tunable_configs(self) -> Dict[int, TunableConfigData]: @property def default_tunable_config_id(self) -> Optional[int]: """ - Retrieves the (tunable) config id for the default tunable values for this experiment. + Retrieves the (tunable) config id for the default tunable values for this + experiment. Note: this is by *default* the first trial executed for this experiment. However, it is currently possible that the user changed the tunables config diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 9a1eca2744..3443c9b810 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -DB schema definition. -""" +"""DB schema definition.""" import logging from typing import Any, List @@ -49,9 +47,7 @@ def __repr__(self) -> str: class DbSchema: - """ - A class to define and create the DB schema. - """ + """A class to define and create the DB schema.""" # This class is internal to SqlStorage and is mostly a struct # for all DB tables, so it's ok to disable the warnings. @@ -64,9 +60,7 @@ class DbSchema: _STATUS_LEN = 16 def __init__(self, engine: Engine): - """ - Declare the SQLAlchemy schema for the database. - """ + """Declare the SQLAlchemy schema for the database.""" _LOG.info("Create the DB schema for: %s", engine) self._engine = engine # TODO: bind for automatic schema updates? (#649) @@ -204,17 +198,15 @@ def __init__(self, engine: Engine): _LOG.debug("Schema: %s", self._meta) def create(self) -> 'DbSchema': - """ - Create the DB schema. 
- """ + """Create the DB schema.""" _LOG.info("Create the DB schema") self._meta.create_all(self._engine) return self def __repr__(self) -> str: """ - Produce a string with all SQL statements required to create the schema - from scratch in current SQL dialect. + Produce a string with all SQL statements required to create the schema from + scratch in current SQL dialect. That is, return a collection of CREATE TABLE statements and such. NOTE: this method is quite heavy! We use it only once at startup diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index bde38575bd..4ac5116b70 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Saving and restoring the benchmark data in SQL database. -""" +"""Saving and restoring the benchmark data in SQL database.""" import logging from typing import Dict, Literal, Optional @@ -23,9 +21,7 @@ class SqlStorage(Storage): - """ - An implementation of the Storage interface using SQLAlchemy backend. - """ + """An implementation of the Storage interface using SQLAlchemy backend.""" def __init__(self, config: dict, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 7ac7958845..2a43c2c671 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Saving and updating benchmark data using SQLAlchemy backend. -""" +"""Saving and updating benchmark data using SQLAlchemy backend.""" import logging from datetime import datetime @@ -23,9 +21,7 @@ class Trial(Storage.Trial): - """ - Store the results of a single run of the experiment in SQL database. - """ + """Store the results of a single run of the experiment in SQL database.""" def __init__(self, *, engine: Engine, @@ -136,6 +132,7 @@ def update_telemetry(self, status: Status, timestamp: datetime, def _update_status(self, conn: Connection, status: Status, timestamp: datetime) -> None: """ Insert a new status record into the database. + This call is idempotent. """ # Make sure to convert the timestamp to UTC before storing it in the database. diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index 5a6f8a5ee8..18fc0b46ff 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -An interface to access the benchmark trial data stored in SQL DB. -""" +"""An interface to access the benchmark trial data stored in SQL DB.""" from datetime import datetime from typing import TYPE_CHECKING, Optional @@ -25,9 +23,7 @@ class TrialSqlData(TrialData): - """ - An interface to access the trial data stored in the SQL DB. - """ + """An interface to access the trial data stored in the SQL DB.""" def __init__(self, *, engine: Engine, @@ -61,8 +57,8 @@ def tunable_config(self) -> TunableConfigData: @property def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": - """ - Retrieve the trial's tunable config group configuration data from the storage. + """Retrieve the trial's tunable config group configuration data from the + storage. 
""" # pylint: disable=import-outside-toplevel from mlos_bench.storage.sql.tunable_config_trial_group_data import ( @@ -74,9 +70,7 @@ def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": @property def results_df(self) -> pandas.DataFrame: - """ - Retrieve the trials' results from the storage. - """ + """Retrieve the trials' results from the storage.""" with self._engine.connect() as conn: cur_results = conn.execute( self._schema.trial_result.select().where( @@ -92,9 +86,7 @@ def results_df(self) -> pandas.DataFrame: @property def telemetry_df(self) -> pandas.DataFrame: - """ - Retrieve the trials' telemetry from the storage. - """ + """Retrieve the trials' telemetry from the storage.""" with self._engine.connect() as conn: cur_telemetry = conn.execute( self._schema.trial_telemetry.select().where( diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py index e484979790..616d5fe823 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -An interface to access the tunable config data stored in SQL DB. -""" +"""An interface to access the tunable config data stored in SQL DB.""" import pandas from sqlalchemy import Engine diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index eb389a5940..457e81e7c0 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -An interface to access the tunable config trial group data stored in SQL DB. -""" +"""An interface to access the tunable config trial group data stored in SQL DB.""" from typing import TYPE_CHECKING, Dict, Optional @@ -25,8 +23,8 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData): """ - SQL interface for accessing the stored experiment benchmark tunable config - trial group data. + SQL interface for accessing the stored experiment benchmark tunable config trial + group data. A (tunable) config is used to define an instance of values for a set of tunable parameters for a given experiment and can be used by one or more trial instances @@ -48,9 +46,7 @@ def __init__(self, *, self._schema = schema def _get_tunable_config_trial_group_id(self) -> int: - """ - Retrieve the trial's tunable_config_trial_group_id from the storage. - """ + """Retrieve the trial's tunable_config_trial_group_id from the storage.""" with self._engine.connect() as conn: tunable_config_trial_group = conn.execute( self._schema.trial.select().with_only_columns( @@ -79,7 +75,8 @@ def tunable_config(self) -> TunableConfigData: @property def trials(self) -> Dict[int, "TrialData"]: """ - Retrieve the trials' data for this (tunable) config trial group from the storage. + Retrieve the trials' data for this (tunable) config trial group from the + storage. Returns ------- diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index 220f3d812c..d1b2547876 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# -""" -Factory method to create a new Storage instance from configs. -""" +"""Factory method to create a new Storage instance from configs.""" from typing import Any, Dict, List, Optional diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py index a4610da8de..1ac48b4fab 100644 --- a/mlos_bench/mlos_bench/storage/util.py +++ b/mlos_bench/mlos_bench/storage/util.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Utility functions for the storage subsystem. -""" +"""Utility functions for the storage subsystem.""" from typing import Dict, Optional diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index 26aa142441..dee543357f 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -4,6 +4,7 @@ # """ Tests for mlos_bench. + Used to make mypy happy about multiple conftest.py modules. """ import filecmp @@ -59,9 +60,7 @@ def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]: - """ - Gets the full class name from the given name or None on error. - """ + """Gets the full class name from the given name or None on error.""" if class_name is None: return None try: @@ -72,9 +71,7 @@ def try_resolve_class_name(class_name: Optional[str]) -> Optional[str]: def check_class_name(obj: object, expected_class_name: str) -> bool: - """ - Compares the class name of the given object with the given name. - """ + """Compares the class name of the given object with the given name.""" full_class_name = obj.__class__.__module__ + "." + obj.__class__.__name__ return full_class_name == try_resolve_class_name(expected_class_name) @@ -119,15 +116,13 @@ def resolve_host_name(host: str) -> Optional[str]: def are_dir_trees_equal(dir1: str, dir2: str) -> bool: """ - Compare two directories recursively. Files in each directory are - assumed to be equal if their names and contents are equal. + Compare two directories recursively. Files in each directory are assumed to be equal + if their names and contents are equal. - @param dir1: First directory path - @param dir2: Second directory path + @param dir1: First directory path @param dir2: Second directory path - @return: True if the directory trees are the same and - there were no errors while accessing the directories or files, - False otherwise. + @return: True if the directory trees are the same and there were no errors while + accessing the directories or files, False otherwise. """ # See Also: https://stackoverflow.com/a/6681395 dirs_cmp = filecmp.dircmp(dir1, dir2) diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index 4d728b4037..ecd1c1ba58 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Helper functions for config example loading tests. -""" +"""Helper functions for config example loading tests.""" import os import sys @@ -24,7 +22,8 @@ def locate_config_examples(root_dir: str, config_examples_dir: str, examples_filter: Optional[Callable[[List[str]], List[str]]] = None) -> List[str]: - """Locates all config examples in the given directory. + """ + Locates all config examples in the given directory. 
Parameters ---------- diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index e1e26d7d8b..1bea4f4369 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for loading storage config examples. -""" +"""Tests for loading storage config examples.""" import logging import sys diff --git a/mlos_bench/mlos_bench/tests/config/conftest.py b/mlos_bench/mlos_bench/tests/config/conftest.py index fdcb3370cf..6f8cebb910 100644 --- a/mlos_bench/mlos_bench/tests/config/conftest.py +++ b/mlos_bench/mlos_bench/tests/config/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Test fixtures for mlos_bench config loader tests. -""" +"""Test fixtures for mlos_bench config loader tests.""" import sys diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 42925a0a5d..1b9103c5af 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for loading environment config examples. -""" +"""Tests for loading environment config examples.""" import logging from typing import List diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index 4d8c93fdff..708bb0f55c 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for loading globals config examples. -""" +"""Tests for loading globals config examples.""" import logging from typing import List diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index 6cb6253dea..ad4dae94f8 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for loading optimizer config examples. -""" +"""Tests for loading optimizer config examples.""" import logging from typing import List diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index e4264003e1..5f2f24e519 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Common tests for config schemas and their validation and test cases. 
-""" +"""Common tests for config schemas and their validation and test cases.""" import os from copy import deepcopy @@ -22,9 +20,7 @@ # A dataclass to make pylint happy. @dataclass class SchemaTestType: - """ - The different type of schema test cases we expect to have. - """ + """The different type of schema test cases we expect to have.""" test_case_type: str test_case_subtypes: Set[str] @@ -42,9 +38,7 @@ def __hash__(self) -> int: @dataclass class SchemaTestCaseInfo(): - """ - Some basic info about a schema test case. - """ + """Some basic info about a schema test case.""" config: Dict[str, Any] test_case_file: str @@ -56,9 +50,8 @@ def __hash__(self) -> int: def check_schema_dir_layout(test_cases_root: str) -> None: - """ - Makes sure the directory layout matches what we expect so we aren't missing - any extra configs or test cases. + """Makes sure the directory layout matches what we expect so we aren't missing any + extra configs or test cases. """ for test_case_dir in os.listdir(test_cases_root): if test_case_dir == 'README.md': @@ -74,9 +67,7 @@ def check_schema_dir_layout(test_cases_root: str) -> None: @dataclass class TestCases: - """ - A container for test cases by type. - """ + """A container for test cases by type.""" by_path: Dict[str, SchemaTestCaseInfo] by_type: Dict[str, Dict[str, SchemaTestCaseInfo]] @@ -84,9 +75,7 @@ class TestCases: def get_schema_test_cases(test_cases_root: str) -> TestCases: - """ - Gets a dict of schema test cases from the given root. - """ + """Gets a dict of schema test cases from the given root.""" test_cases = TestCases(by_path={}, by_type={x: {} for x in _SCHEMA_TEST_TYPES}, by_subtype={y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes}) @@ -143,8 +132,8 @@ def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: C def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: - """ - Checks that the config fails to validate if extra params are present in certain places. + """Checks that the config fails to validate if extra params are present in certain + places. """ config = deepcopy(test_case.config) schema_type.validate(config) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 5dd1666008..3ef2b56654 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for CLI schema validation. -""" +"""Tests for CLI schema validation.""" from os import path @@ -28,9 +26,7 @@ @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_cli_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the CLI config validates against the schema. - """ + """Checks that the CLI config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.CLI) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. 
@@ -41,8 +37,8 @@ def test_cli_configs_against_schema(test_case_name: str) -> None: @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_cli_configs_with_extra_param(test_case_name: str) -> None: - """ - Checks that the cli config fails to validate if extra params are present in certain places. + """Checks that the cli config fails to validate if extra params are present in + certain places. """ check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index dc3cd40425..84381c4a6b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for environment schema validation. -""" +"""Tests for environment schema validation.""" from os import path @@ -50,8 +48,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("env_class", expected_environment_class_names) def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_class: str) -> None: - """ - Checks to see if there is a given type of test case for the given mlos_bench Environment type. + """Checks to see if there is a given type of test case for the given mlos_bench + Environment type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): if try_resolve_class_name(test_case.config.get("class")) == env_class: @@ -64,17 +62,15 @@ def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_c @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_environment_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the environment config validates against the schema. - """ + """Checks that the environment config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.ENVIRONMENT) check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_environment_configs_with_extra_param(test_case_name: str) -> None: - """ - Checks that the environment config fails to validate if extra params are present in certain places. + """Checks that the environment config fails to validate if extra params are present + in certain places. """ check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT) check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index 5045bf510b..f5a5b83f9f 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for CLI schema validation. 
-""" +"""Tests for CLI schema validation.""" from os import path @@ -27,9 +25,7 @@ @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_globals_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the CLI config validates against the schema. - """ + """Checks that the CLI config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index e9ee653644..0c05cf7323 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for optimizer schema validation. -""" +"""Tests for optimizer schema validation.""" from os import path from typing import Optional @@ -51,8 +49,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names) def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_bench_optimizer_type: str) -> None: - """ - Checks to see if there is a given type of test case for the given mlos_bench optimizer type. + """Checks to see if there is a given type of test case for the given mlos_bench + optimizer type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type: @@ -69,8 +67,8 @@ def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_be @pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types) def test_case_coverage_mlos_core_optimizer_type(test_case_type: str, mlos_core_optimizer_type: Optional[OptimizerType]) -> None: - """ - Checks to see if there is a given type of test case for the given mlos_core optimizer type. + """Checks to see if there is a given type of test case for the given mlos_core + optimizer type. """ optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): @@ -90,8 +88,8 @@ def test_case_coverage_mlos_core_optimizer_type(test_case_type: str, @pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types) def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str, mlos_core_space_adapter_type: Optional[SpaceAdapterType]) -> None: - """ - Checks to see if there is a given type of test case for the given mlos_core space adapter type. + """Checks to see if there is a given type of test case for the given mlos_core space + adapter type. """ space_adapter_name = None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): @@ -110,17 +108,15 @@ def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str, @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_optimizer_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the optimizer config validates against the schema. 
- """ + """Checks that the optimizer config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.OPTIMIZER) check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: - """ - Checks that the optimizer config fails to validate if extra params are present in certain places. + """Checks that the optimizer config fails to validate if extra params are present in + certain places. """ check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER) check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 8fccba8bc7..0908252971 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for schedulers schema validation. -""" +"""Tests for schedulers schema validation.""" from os import path @@ -41,8 +39,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names) def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_bench_scheduler_type: str) -> None: - """ - Checks to see if there is a given type of test case for the given mlos_bench scheduler type. + """Checks to see if there is a given type of test case for the given mlos_bench + scheduler type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_scheduler_type: @@ -55,17 +53,15 @@ def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_be @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_scheduler_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the scheduler config validates against the schema. - """ + """Checks that the scheduler config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SCHEDULER) check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: - """ - Checks that the scheduler config fails to validate if extra params are present in certain places. + """Checks that the scheduler config fails to validate if extra params are present in + certain places. 
""" check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER) check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 64c6fccccd..92b8e69110 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for service schema validation. -""" +"""Tests for service schema validation.""" from os import path from typing import Any, Dict, List @@ -55,8 +53,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("service_class", expected_service_class_names) def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_class: str) -> None: - """ - Checks to see if there is a given type of test case for the given mlos_bench Service type. + """Checks to see if there is a given type of test case for the given mlos_bench + Service type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): config_list: List[Dict[str, Any]] @@ -77,17 +75,15 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_service_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the service config validates against the schema. - """ + """Checks that the service config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.SERVICE) check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_service_configs_with_extra_param(test_case_name: str) -> None: - """ - Checks that the service config fails to validate if extra params are present in certain places. + """Checks that the service config fails to validate if extra params are present in + certain places. """ check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE) check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index 9b362b5e0d..fec23c8284 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for storage schema validation. -""" +"""Tests for storage schema validation.""" from os import path @@ -39,8 +37,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names) def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_bench_storage_type: str) -> None: - """ - Checks to see if there is a given type of test case for the given mlos_bench storage type. 
+ """Checks to see if there is a given type of test case for the given mlos_bench + storage type. """ for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_storage_type: @@ -53,17 +51,15 @@ def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_benc @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_storage_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the storage config validates against the schema. - """ + """Checks that the storage config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.STORAGE) check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_type["good"])) def test_storage_configs_with_extra_param(test_case_name: str) -> None: - """ - Checks that the storage config fails to validate if extra params are present in certain places. + """Checks that the storage config fails to validate if extra params are present in + certain places. """ check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE) check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py index a6d0de9313..762314961e 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for tunable params schema validation. -""" +"""Tests for tunable params schema validation.""" from os import path @@ -27,8 +25,6 @@ @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_params_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the tunable params config validates against the schema. - """ + """Checks that the tunable params config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_PARAMS) check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index d871eaa212..0426373a90 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for tunable values schema validation. -""" +"""Tests for tunable values schema validation.""" from os import path @@ -27,9 +25,7 @@ @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_values_configs_against_schema(test_case_name: str) -> None: - """ - Checks that the tunable values config validates against the schema. 
- """ + """Checks that the tunable values config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 32034eb11c..b5ac6380ed 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for loading service config examples. -""" +"""Tests for loading service config examples.""" import logging from typing import List diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index 2f9773a9b0..ff2c8c6e5b 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for loading storage config examples. -""" +"""Tests for loading storage config examples.""" import logging from typing import List diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index 58359eb983..2fc5268c26 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Common fixtures for mock TunableGroups and Environment objects. -""" +"""Common fixtures for mock TunableGroups and Environment objects.""" import os from typing import Any, Generator, List @@ -31,9 +29,7 @@ @pytest.fixture def mock_env(tunable_groups: TunableGroups) -> MockEnv: - """ - Test fixture for MockEnv. - """ + """Test fixture for MockEnv.""" return MockEnv( name="Test Env", config={ @@ -48,9 +44,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv: @pytest.fixture def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: - """ - Test fixture for MockEnv. - """ + """Test fixture for MockEnv.""" return MockEnv( name="Test Env No Noise", config={ @@ -105,8 +99,8 @@ def docker_compose_project_name(short_testrun_uid: str) -> str: @pytest.fixture(scope="session") def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessReaderWriterLock: """ - Gets a pytest session lock for xdist workers to mark when they're using the - docker services. + Gets a pytest session lock for xdist workers to mark when they're using the docker + services. Yields ------ @@ -119,8 +113,8 @@ def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterP @pytest.fixture(scope="session") def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessLock: """ - Gets a pytest session lock between xdist workers for the docker - setup/teardown operations. + Gets a pytest session lock between xdist workers for the docker setup/teardown + operations. 
Yields ------ @@ -139,8 +133,8 @@ def locked_docker_services( docker_setup_teardown_lock: InterProcessLock, docker_services_lock: InterProcessReaderWriterLock, ) -> Generator[DockerServices, Any, None]: - """ - A locked version of the docker_services fixture to implement xdist single instance locking. + """A locked version of the docker_services fixture to implement xdist single + instance locking. """ # pylint: disable=too-many-arguments # Mark the services as in use with the reader lock. diff --git a/mlos_bench/mlos_bench/tests/dict_templater_test.py b/mlos_bench/mlos_bench/tests/dict_templater_test.py index 63219d9246..4b64f50fd4 100644 --- a/mlos_bench/mlos_bench/tests/dict_templater_test.py +++ b/mlos_bench/mlos_bench/tests/dict_templater_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for DictTemplater class. -""" +"""Unit tests for DictTemplater class.""" from copy import deepcopy from typing import Any, Dict @@ -46,9 +44,7 @@ def source_template_dict() -> Dict[str, Any]: def test_no_side_effects(source_template_dict: Dict[str, Any]) -> None: - """ - Test that the templater does not modify the source dictionary. - """ + """Test that the templater does not modify the source dictionary.""" source_template_dict_copy = deepcopy(source_template_dict) results = DictTemplater(source_template_dict_copy).expand_vars() assert results @@ -56,9 +52,7 @@ def test_no_side_effects(source_template_dict: Dict[str, Any]) -> None: def test_secondary_expansion(source_template_dict: Dict[str, Any]) -> None: - """ - Test that internal expansions work as expected. - """ + """Test that internal expansions work as expected.""" results = DictTemplater(source_template_dict).expand_vars() assert results == { "extra_str-ref": "$extra_str-ref", @@ -85,9 +79,7 @@ def test_secondary_expansion(source_template_dict: Dict[str, Any]) -> None: def test_os_env_expansion(source_template_dict: Dict[str, Any]) -> None: - """ - Test that expansions from OS env work as expected. - """ + """Test that expansions from OS env work as expected.""" environ["extra_str"] = "os-env-extra_str" environ["string"] = "shouldn't be used" @@ -117,9 +109,7 @@ def test_os_env_expansion(source_template_dict: Dict[str, Any]) -> None: def test_from_extras_expansion(source_template_dict: Dict[str, Any]) -> None: - """ - Test that - """ + """Test that.""" extra_source_dict = { "extra_str": "str-from-extras", "string": "shouldn't be used", diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py index ac0b942167..667a31d69d 100644 --- a/mlos_bench/mlos_bench/tests/environments/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests helpers for mlos_bench.environments. -""" +"""Tests helpers for mlos_bench.environments.""" from datetime import datetime from typing import Any, Dict, List, Optional, Tuple @@ -55,8 +53,8 @@ def check_env_success(env: Environment, def check_env_fail_telemetry(env: Environment, tunable_groups: TunableGroups) -> None: """ - Set up a local environment and run a test experiment there; - Make sure the environment `.status()` call fails. + Set up a local environment and run a test experiment there; Make sure the + environment `.status()` call fails. 
Parameters ---------- diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index 8afb8e5cda..52bea41524 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for base environment class functionality. -""" +"""Unit tests for base environment class functionality.""" from typing import Dict @@ -25,40 +23,30 @@ def test_expand_groups() -> None: - """ - Check the dollar variable expansion for tunable groups. - """ + """Check the dollar variable expansion for tunable groups.""" assert Environment._expand_groups( ["begin", "$list", "$empty", "$str", "end"], _GROUPS) == ["begin", "c", "d", "efg", "end"] def test_expand_groups_empty_input() -> None: - """ - Make sure an empty group stays empty. - """ + """Make sure an empty group stays empty.""" assert Environment._expand_groups([], _GROUPS) == [] def test_expand_groups_empty_list() -> None: - """ - Make sure an empty group expansion works properly. - """ + """Make sure an empty group expansion works properly.""" assert not Environment._expand_groups(["$empty"], _GROUPS) def test_expand_groups_unknown() -> None: - """ - Make sure we fail on unknown $GROUP names expansion. - """ + """Make sure we fail on unknown $GROUP names expansion.""" with pytest.raises(KeyError): Environment._expand_groups(["$list", "$UNKNOWN", "$str", "end"], _GROUPS) def test_expand_const_args() -> None: - """ - Test expansion of const args via expand_vars. - """ + """Test expansion of const args via expand_vars.""" const_args: Dict[str, TunableValue] = { "a": "b", "foo": "$bar/baz", diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index 6497eb6985..c6c6fff78f 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Check how the services get inherited and overridden in child environments. -""" +"""Check how the services get inherited and overridden in child environments.""" import os import pytest @@ -20,9 +18,7 @@ @pytest.fixture def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: - """ - Test fixture for CompositeEnv with services included on multiple levels. - """ + """Test fixture for CompositeEnv with services included on multiple levels.""" return CompositeEnv( name="Root", config={ @@ -58,9 +54,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: def test_composite_services(composite_env: CompositeEnv) -> None: - """ - Check that each environment gets its own instance of the services. 
- """ + """Check that each environment gets its own instance of the services.""" for (i, path) in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): service = composite_env.children[i]._service # pylint: disable=protected-access assert service is not None and hasattr(service, "temp_dir_context") diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index 742eaf3c79..0f2669e85a 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for composite environment. -""" +"""Unit tests for composite environment.""" import pytest @@ -17,9 +15,7 @@ @pytest.fixture def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: - """ - Test fixture for CompositeEnv. - """ + """Test fixture for CompositeEnv.""" return CompositeEnv( name="Composite Test Environment", config={ @@ -86,7 +82,9 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: def test_composite_env_params(composite_env: CompositeEnv) -> None: """ - Check that the const_args from the parent environment get propagated to the children. + Check that the const_args from the parent environment get propagated to the + children. + NOTE: The current logic is that variables flow down via required_args and const_args, parent """ assert composite_env.children[0].parameters == { @@ -115,9 +113,7 @@ def test_composite_env_params(composite_env: CompositeEnv) -> None: def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: - """ - Check that the child environments update their tunable parameters. - """ + """Check that the child environments update their tunable parameters.""" tunable_groups.assign({ "vmSize": "Standard_B2s", "idle": "mwait", @@ -153,9 +149,7 @@ def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: Tunabl @pytest.fixture def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: - """ - Test fixture for CompositeEnv. - """ + """Test fixture for CompositeEnv.""" return CompositeEnv( name="Composite Test Environment", config={ @@ -239,7 +233,9 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None: """ - Check that the const_args from the parent environment get propagated to the children. + Check that the const_args from the parent environment get propagated to the + children. + NOTE: The current logic is that variables flow down via required_args and const_args, parent """ assert isinstance(nested_composite_env.children[0], CompositeEnv) @@ -263,9 +259,7 @@ def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None def test_nested_composite_env_setup(nested_composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: - """ - Check that the child environments update their tunable parameters. 
- """ + """Check that the child environments update their tunable parameters.""" tunable_groups.assign({ "vmSize": "Standard_B2s", "idle": "mwait", diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py index 7395aa3e15..0450dfa44d 100644 --- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py +++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Test the selection of tunables / tunable groups for the environment. -""" +"""Test the selection of tunables / tunable groups for the environment.""" from mlos_bench.environments.mock_env import MockEnv from mlos_bench.services.config_persistence import ConfigPersistenceService @@ -12,9 +10,7 @@ def test_one_group(tunable_groups: TunableGroups) -> None: - """ - Make sure only one tunable group is available to the environment. - """ + """Make sure only one tunable group is available to the environment.""" env = MockEnv( name="Test Env", config={"tunable_params": ["provision"]}, @@ -26,9 +22,7 @@ def test_one_group(tunable_groups: TunableGroups) -> None: def test_two_groups(tunable_groups: TunableGroups) -> None: - """ - Make sure only the selected tunable groups are available to the environment. - """ + """Make sure only the selected tunable groups are available to the environment.""" env = MockEnv( name="Test Env", config={"tunable_params": ["provision", "kernel"]}, @@ -42,9 +36,8 @@ def test_two_groups(tunable_groups: TunableGroups) -> None: def test_two_groups_setup(tunable_groups: TunableGroups) -> None: - """ - Make sure only the selected tunable groups are available to the environment, - the set is not changed after calling the `.setup()` method. + """Make sure only the selected tunable groups are available to the environment, the + set is not changed after calling the `.setup()` method. """ env = MockEnv( name="Test Env", @@ -77,9 +70,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None: def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None: - """ - Make sure that no tunable groups are available to the environment by default. - """ + """Make sure that no tunable groups are available to the environment by default.""" env = MockEnv( name="Test Env", config={}, @@ -89,9 +80,8 @@ def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None: def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None: - """ - Make sure that no tunable groups are available to the environment - when explicitly specifying an empty list of tunable_params. + """Make sure that no tunable groups are available to the environment when explicitly + specifying an empty list of tunable_params. """ env = MockEnv( name="Test Env", @@ -102,9 +92,8 @@ def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None: def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None: - """ - Make sure that no tunable groups are available to the environment by default - and it does not change after the setup. + """Make sure that no tunable groups are available to the environment by default and + it does not change after the setup. 
""" env = MockEnv( name="Test Env", @@ -130,9 +119,8 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None: def test_loader_level_include() -> None: - """ - Make sure only the selected tunable groups are available to the environment, - the set is not changed after calling the `.setup()` method. + """Make sure only the selected tunable groups are available to the environment, the + set is not changed after calling the `.setup()` method. """ env_json = { "class": "mlos_bench.environments.mock_env.MockEnv", diff --git a/mlos_bench/mlos_bench/tests/environments/local/__init__.py b/mlos_bench/mlos_bench/tests/environments/local/__init__.py index 5d8fc32c6b..4ef31ec299 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/local/__init__.py @@ -4,6 +4,7 @@ # """ Tests for mlos_bench.environments.local. + Used to make mypy happy about multiple conftest.py modules. """ diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index 9bcb7aa218..f8a8271a7f 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for the composition of several LocalEnv benchmark environments. -""" +"""Unit tests for the composition of several LocalEnv benchmark environments.""" import sys from datetime import datetime, timedelta, tzinfo from typing import Optional @@ -28,8 +26,9 @@ def _format_str(zone_info: Optional[tzinfo]) -> str: @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: """ - Produce benchmark and telemetry data in TWO local environments - and combine the results. + Produce benchmark and telemetry data in TWO local environments and combine the + results. + Also checks that global configs flow down at least one level of CompositeEnv to its children without being explicitly specified in the CompositeEnv so they can be used in the shell_envs by its children. diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py index 20854b9f9e..fcdd9b1eab 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for extracting data from LocalEnv stdout. -""" +"""Unit tests for extracting data from LocalEnv stdout.""" import sys @@ -14,9 +12,7 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None: - """ - Print benchmark results to stdout and capture them in the LocalEnv. - """ + """Print benchmark results to stdout and capture them in the LocalEnv.""" local_env = create_local_env(tunable_groups, { "run": [ "echo 'Benchmark results:'", # This line should be ignored @@ -42,9 +38,7 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None: def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None: - """ - Print benchmark results to stdout and capture them in the LocalEnv. 
- """ + """Print benchmark results to stdout and capture them in the LocalEnv.""" local_env = create_local_env(tunable_groups, { "run": [ "echo 'Benchmark results:'", # This line should be ignored @@ -69,8 +63,8 @@ def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None: def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None: - """ - Print benchmark results to *BOTH* stdout and a file and extract the results from both. + """Print benchmark results to *BOTH* stdout and a file and extract the results from + both. """ local_env = create_local_env(tunable_groups, { "run": [ diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index 35bdb39486..6fb2718706 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for telemetry and status of LocalEnv benchmark environment. -""" +"""Unit tests for telemetry and status of LocalEnv benchmark environment.""" from datetime import datetime, timedelta, tzinfo from typing import Optional @@ -26,9 +24,7 @@ def _format_str(zone_info: Optional[tzinfo]) -> str: # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Produce benchmark and telemetry data in a local script and read it. - """ + """Produce benchmark and telemetry data in a local script and read it.""" ts1 = datetime.now(zone_info) ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + timedelta(minutes=1) @@ -73,9 +69,7 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Read the telemetry data with no header. - """ + """Read the telemetry data with no header.""" ts1 = datetime.now(zone_info) ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + timedelta(minutes=1) @@ -109,9 +103,7 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: @pytest.mark.filterwarnings("ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0") # pylint: disable=line-too-long # noqa @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Read the telemetry data with incorrect header. - """ + """Read the telemetry data with incorrect header.""" ts1 = datetime.now(zone_info) ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second ts2 = ts1 + timedelta(minutes=1) @@ -136,9 +128,7 @@ def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_in def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: - """ - Fail when the telemetry data has wrong format. 
- """ + """Fail when the telemetry data has wrong format.""" zone_info = UTC ts1 = datetime.now(zone_info) ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second @@ -163,9 +153,7 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: def test_local_env_telemetry_invalid_ts(tunable_groups: TunableGroups) -> None: - """ - Fail when the telemetry data has wrong format. - """ + """Fail when the telemetry data has wrong format.""" local_env = create_local_env(tunable_groups, { "run": [ # Error: field 1 must be a timestamp diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 6cb4fd4f7e..5ba125c028 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for LocalEnv benchmark environment. -""" +"""Unit tests for LocalEnv benchmark environment.""" import pytest from mlos_bench.tests.environments import check_env_success @@ -13,9 +11,7 @@ def test_local_env(tunable_groups: TunableGroups) -> None: - """ - Produce benchmark and telemetry data in a local script and read it. - """ + """Produce benchmark and telemetry data in a local script and read it.""" local_env = create_local_env(tunable_groups, { "run": [ "echo 'metric,value' > output.csv", @@ -38,8 +34,8 @@ def test_local_env(tunable_groups: TunableGroups) -> None: def test_local_env_service_context(tunable_groups: TunableGroups) -> None: - """ - Basic check that context support for Service mixins are handled when environment contexts are entered. + """Basic check that context support for Service mixins are handled when environment + contexts are entered. """ local_env = create_local_env(tunable_groups, { "run": ["echo NA"] @@ -60,9 +56,7 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: - """ - Fail if the results are not in the expected format. - """ + """Fail if the results are not in the expected format.""" local_env = create_local_env(tunable_groups, { "run": [ # No header @@ -80,9 +74,7 @@ def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: def test_local_env_wide(tunable_groups: TunableGroups) -> None: - """ - Produce benchmark data in wide format and read it. - """ + """Produce benchmark data in wide format and read it.""" local_env = create_local_env(tunable_groups, { "run": [ "echo 'latency,throughput,score' > output.csv", diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py index c16eac4459..16fd53959c 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for passing shell environment variables into LocalEnv scripts. -""" +"""Unit tests for passing shell environment variables into LocalEnv scripts.""" import sys import pytest @@ -15,9 +13,7 @@ def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: dict) -> None: - """ - Check that LocalEnv can set shell environment variables. 
- """ + """Check that LocalEnv can set shell environment variables.""" local_env = create_local_env(tunable_groups, { "const_args": { "const_arg": 111, # Passed into "shell_env_params" @@ -40,9 +36,7 @@ def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: d @pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only") def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: - """ - Check that LocalEnv can set shell environment variables in sh-like shell. - """ + """Check that LocalEnv can set shell environment variables in sh-like shell.""" _run_local_env( tunable_groups, shell_subcmd="$const_arg,$other_arg,$unknown_arg,$kernel_sched_latency_ns", @@ -57,8 +51,8 @@ def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: @pytest.mark.skipif(sys.platform != 'win32', reason="Windows only") def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: - """ - Check that LocalEnv can set shell environment variables on Windows / cmd shell. + """Check that LocalEnv can set shell environment variables on Windows / cmd + shell. """ _run_local_env( tunable_groups, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index 8bce053f7b..8f703c1d01 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for passing shell environment variables into LocalEnv scripts. -""" +"""Unit tests for passing shell environment variables into LocalEnv scripts.""" import pytest from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv @@ -20,9 +18,7 @@ @pytest.fixture(scope="module") def mock_fileshare_service() -> MockFileShareService: - """ - Create a new mock FileShareService instance. - """ + """Create a new mock FileShareService instance.""" return MockFileShareService( config={"fileShareName": "MOCK_FILESHARE"}, parent=LocalExecService(parent=ConfigPersistenceService()) @@ -32,9 +28,7 @@ def mock_fileshare_service() -> MockFileShareService: @pytest.fixture def local_fileshare_env(tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService) -> LocalFileShareEnv: - """ - Create a LocalFileShareEnv instance. - """ + """Create a LocalFileShareEnv instance.""" env = LocalFileShareEnv( name="TestLocalFileShareEnv", config={ @@ -76,9 +70,8 @@ def local_fileshare_env(tunable_groups: TunableGroups, def test_local_fileshare_env(tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService, local_fileshare_env: LocalFileShareEnv) -> None: - """ - Test that the LocalFileShareEnv correctly expands the `$VAR` variables - in the upload and download sections of the config. + """Test that the LocalFileShareEnv correctly expands the `$VAR` variables in the + upload and download sections of the config. """ with local_fileshare_env as env_context: assert env_context.setup(tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py index 608edbf9ef..b055f4f6aa 100644 --- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# -""" -Unit tests for mock benchmark environment. -""" +"""Unit tests for mock benchmark environment.""" import pytest from mlos_bench.environments.mock_env import MockEnv @@ -12,9 +10,7 @@ def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> None: - """ - Check the default values of the mock environment. - """ + """Check the default values of the mock environment.""" with mock_env as env_context: assert env_context.setup(tunable_groups) (status, _ts, data) = env_context.run() @@ -29,9 +25,7 @@ def test_mock_env_default(mock_env: MockEnv, tunable_groups: TunableGroups) -> N def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups) -> None: - """ - Check the default values of the mock environment. - """ + """Check the default values of the mock environment.""" with mock_env_no_noise as env_context: assert env_context.setup(tunable_groups) for _ in range(10): @@ -56,9 +50,7 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr ]) def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, tunable_values: dict, expected_score: float) -> None: - """ - Check the benchmark values of the mock environment after the assignment. - """ + """Check the benchmark values of the mock environment after the assignment.""" with mock_env as env_context: tunable_groups.assign(tunable_values) assert env_context.setup(tunable_groups) @@ -83,8 +75,8 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, tunable_groups: TunableGroups, tunable_values: dict, expected_score: float) -> None: - """ - Check the benchmark values of the noiseless mock environment after the assignment. + """Check the benchmark values of the noiseless mock environment after the + assignment. """ with mock_env_no_noise as env_context: tunable_groups.assign(tunable_values) diff --git a/mlos_bench/mlos_bench/tests/environments/remote/__init__.py b/mlos_bench/mlos_bench/tests/environments/remote/__init__.py index f8a576c536..a72cac05db 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/__init__.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Helpers for RemoteEnv tests. -""" +"""Helpers for RemoteEnv tests.""" diff --git a/mlos_bench/mlos_bench/tests/environments/remote/conftest.py b/mlos_bench/mlos_bench/tests/environments/remote/conftest.py index 4e9e4197e8..257e37fa9e 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/conftest.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Fixtures for the RemoteEnv tests using SSH Services. -""" +"""Fixtures for the RemoteEnv tests using SSH Services.""" import mlos_bench.tests.services.remote.ssh.fixtures as ssh_fixtures diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index 878531d799..6fb9dba8c4 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for RemoveEnv benchmark environment via local SSH test services. 
-""" +"""Unit tests for RemoveEnv benchmark environment via local SSH test services.""" import os import sys @@ -28,9 +26,7 @@ @requires_docker def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None: - """ - Produce benchmark and telemetry data in a local script and read it. - """ + """Produce benchmark and telemetry data in a local script and read it.""" global_config: Dict[str, TunableValue] = { "ssh_hostname": ssh_test_server.hostname, "ssh_port": ssh_test_server.get_port(), diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index 377bc940a0..eee8f53304 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench.event_loop_context background thread logic. -""" +"""Tests for mlos_bench.event_loop_context background thread logic.""" import asyncio import sys @@ -23,6 +21,7 @@ class EventLoopContextCaller: """ Simple class to test the EventLoopContext. + See Also: SshService """ diff --git a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py index 90aa7e08f7..04750b7c2a 100644 --- a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests to check the launcher and the main optimization loop in-process. -""" +"""Unit tests to check the launcher and the main optimization loop in-process.""" from typing import List @@ -29,9 +27,7 @@ ] ) def test_main_bench(argv: List[str], expected_score: float) -> None: - """ - Run mlos_bench optimization loop with given config and check the results. - """ + """Run mlos_bench optimization loop with given config and check the results.""" (score, _config) = _main(argv) assert score is not None assert pytest.approx(score["score"], 1e-5) == expected_score diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index 634050d099..687436d316 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -55,8 +55,9 @@ def config_paths() -> List[str]: def test_launcher_args_parse_1(config_paths: List[str]) -> None: """ - Test that using multiple --globals arguments works and that multiple space - separated options to --config-paths works. + Test that using multiple --globals arguments works and that multiple space separated + options to --config-paths works. + Check $var expansion and Environment loading. """ # The VSCode pytest wrapper actually starts in a different directory before @@ -113,8 +114,7 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: def test_launcher_args_parse_2(config_paths: List[str]) -> None: - """ - Test multiple --config-path instances, --config file vs --arg, --var=val + """Test multiple --config-path instances, --config file vs --arg, --var=val overrides, $var templates, option args, --random-init, etc. 
""" # The VSCode pytest wrapper actually starts in a different directory before diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index 591501d275..04aad14faf 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests to check the main CLI launcher. -""" +"""Unit tests to check the main CLI launcher.""" import os import re from typing import List @@ -20,17 +18,13 @@ @pytest.fixture def root_path() -> str: - """ - Root path of mlos_bench project. - """ + """Root path of mlos_bench project.""" return path_join(os.path.dirname(__file__), "../../..", abs_path=True) @pytest.fixture def local_exec_service() -> LocalExecService: - """ - Test fixture for LocalExecService. - """ + """Test fixture for LocalExecService.""" return LocalExecService(parent=ConfigPersistenceService({ "config_path": [ "mlos_bench/config", @@ -41,9 +35,8 @@ def local_exec_service() -> LocalExecService: def _launch_main_app(root_path: str, local_exec_service: LocalExecService, cli_config: str, re_expected: List[str]) -> None: - """ - Run mlos_bench command-line application with given config - and check the results in the log. + """Run mlos_bench command-line application with given config and check the results + in the log. """ with local_exec_service.temp_dir_context() as temp_dir: @@ -74,9 +67,8 @@ def _launch_main_app(root_path: str, local_exec_service: LocalExecService, def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecService) -> None: - """ - Run mlos_bench command-line application with mock benchmark config - and default tunable values and check the results in the log. + """Run mlos_bench command-line application with mock benchmark config and default + tunable values and check the results in the log. """ _launch_main_app( root_path, local_exec_service, @@ -92,9 +84,8 @@ def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecServ def test_launch_main_app_bench_values( root_path: str, local_exec_service: LocalExecService) -> None: - """ - Run mlos_bench command-line application with mock benchmark config - and user-specified tunable values and check the results in the log. + """Run mlos_bench command-line application with mock benchmark config and user- + specified tunable values and check the results in the log. """ _launch_main_app( root_path, local_exec_service, @@ -110,9 +101,8 @@ def test_launch_main_app_bench_values( def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecService) -> None: - """ - Run mlos_bench command-line application with mock optimization config - and check the results in the log. + """Run mlos_bench command-line application with mock optimization config and check + the results in the log. """ _launch_main_app( root_path, local_exec_service, diff --git a/mlos_bench/mlos_bench/tests/optimizers/__init__.py b/mlos_bench/mlos_bench/tests/optimizers/__init__.py index 509ecbd842..dbee44936d 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/tests/optimizers/__init__.py @@ -4,5 +4,6 @@ # """ Tests for mlos_bench.optimizers. + Used to make mypy happy about multiple conftest.py modules. 
""" diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index 59a0fac13b..810f4fcc0e 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Test fixtures for mlos_bench optimizers. -""" +"""Test fixtures for mlos_bench optimizers.""" from typing import List @@ -18,9 +16,7 @@ @pytest.fixture def mock_configs() -> List[dict]: - """ - Mock configurations of earlier experiments. - """ + """Mock configurations of earlier experiments.""" return [ { 'vmSize': 'Standard_B4ms', @@ -51,9 +47,7 @@ def mock_configs() -> List[dict]: @pytest.fixture def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: - """ - Test fixture for MockOptimizer that ignores the initial configuration. - """ + """Test fixture for MockOptimizer that ignores the initial configuration.""" return MockOptimizer( tunables=tunable_groups, service=None, @@ -68,9 +62,7 @@ def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: @pytest.fixture def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: - """ - Test fixture for MockOptimizer. - """ + """Test fixture for MockOptimizer.""" return MockOptimizer( tunables=tunable_groups, service=None, @@ -84,9 +76,7 @@ def mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: @pytest.fixture def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: - """ - Test fixture for MockOptimizer. - """ + """Test fixture for MockOptimizer.""" return MockOptimizer( tunables=tunable_groups, service=None, @@ -100,9 +90,7 @@ def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: @pytest.fixture def flaml_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """ - Test fixture for mlos_core FLAML optimizer. - """ + """Test fixture for mlos_core FLAML optimizer.""" return MlosCoreOptimizer( tunables=tunable_groups, service=None, @@ -117,9 +105,7 @@ def flaml_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: @pytest.fixture def flaml_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """ - Test fixture for mlos_core FLAML optimizer. - """ + """Test fixture for mlos_core FLAML optimizer.""" return MlosCoreOptimizer( tunables=tunable_groups, service=None, @@ -142,9 +128,7 @@ def flaml_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer: @pytest.fixture def smac_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """ - Test fixture for mlos_core SMAC optimizer. - """ + """Test fixture for mlos_core SMAC optimizer.""" return MlosCoreOptimizer( tunables=tunable_groups, service=None, @@ -163,9 +147,7 @@ def smac_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: @pytest.fixture def smac_opt_max(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """ - Test fixture for mlos_core SMAC optimizer. - """ + """Test fixture for mlos_core SMAC optimizer.""" return MlosCoreOptimizer( tunables=tunable_groups, service=None, diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index cfecb02058..077b2ed058 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# -""" -Unit tests for grid search mlos_bench optimizer. -""" +"""Unit tests for grid search mlos_bench optimizer.""" import itertools import math @@ -23,9 +21,7 @@ @pytest.fixture def grid_search_tunables_config() -> dict: - """ - Test fixture for grid search optimizer tunables config. - """ + """Test fixture for grid search optimizer tunables config.""" return { "grid": { "cost": 1, @@ -55,6 +51,7 @@ def grid_search_tunables_config() -> dict: def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]: """ Test fixture for grid from tunable groups. + Used to check that the grids are the same (ignoring order). """ tunables_params_values = [tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None] @@ -64,18 +61,14 @@ def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[ @pytest.fixture def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups: - """ - Test fixture for grid search optimizer tunables. - """ + """Test fixture for grid search optimizer tunables.""" return TunableGroups(grid_search_tunables_config) @pytest.fixture def grid_search_opt(grid_search_tunables: TunableGroups, grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer: - """ - Test fixture for grid search optimizer. - """ + """Test fixture for grid search optimizer.""" assert len(grid_search_tunables) == 3 # Test the convergence logic by controlling the number of iterations to be not a # multiple of the number of elements in the grid. @@ -89,9 +82,7 @@ def grid_search_opt(grid_search_tunables: TunableGroups, def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups, grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: - """ - Make sure that grid search optimizer initializes and works correctly. - """ + """Make sure that grid search optimizer initializes and works correctly.""" # Check the size. expected_grid_size = math.prod(tunable.cardinality for tunable, _group in grid_search_tunables) assert expected_grid_size > len(grid_search_tunables) @@ -118,9 +109,7 @@ def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, def test_grid_search(grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups, grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: - """ - Make sure that grid search optimizer initializes and works correctly. - """ + """Make sure that grid search optimizer initializes and works correctly.""" score: Dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0} status = Status.SUCCEEDED suggestion = grid_search_opt.suggest() @@ -190,8 +179,7 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: - """ - Make sure that grid search optimizer works correctly when suggest and register + """Make sure that grid search optimizer works correctly when suggest and register are called out of order. """ # pylint: disable=too-many-locals @@ -258,9 +246,7 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: def test_grid_search_register(grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups) -> None: - """ - Make sure that the `.register()` method adjusts the score signs correctly. 
- """ + """Make sure that the `.register()` method adjusts the score signs correctly.""" assert grid_search_opt.register( grid_search_tunables, Status.SUCCEEDED, { "score": 1.0, diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py index 6549a8795c..0a0add5b24 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mock mlos_bench optimizer. -""" +"""Unit tests for mock mlos_bench optimizer.""" import pytest @@ -18,9 +16,7 @@ @pytest.fixture def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """ - Test fixture for mlos_core SMAC optimizer. - """ + """Test fixture for mlos_core SMAC optimizer.""" return MlosCoreOptimizer( tunables=tunable_groups, service=None, @@ -39,16 +35,12 @@ def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: @pytest.fixture def mock_scores() -> list: - """ - A list of fake benchmark scores to test the optimizers. - """ + """A list of fake benchmark scores to test the optimizers.""" return [88.88, 66.66, 99.99] def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list) -> None: - """ - Make sure that llamatune+smac optimizer initializes and works correctly. - """ + """Make sure that llamatune+smac optimizer initializes and works correctly.""" for score in mock_scores: assert llamatune_opt.not_converged() tunables = llamatune_opt.suggest() diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py index 7ebba0e664..47768f87a4 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for internal methods of the `MlosCoreOptimizer`. -""" +"""Unit tests for internal methods of the `MlosCoreOptimizer`.""" from typing import List @@ -20,9 +18,7 @@ @pytest.fixture def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer: - """ - An instance of a mlos_core optimizer (FLAML-based). - """ + """An instance of a mlos_core optimizer (FLAML-based).""" test_opt_config = { 'optimizer_type': 'FLAML', 'max_suggestions': 10, @@ -32,9 +28,7 @@ def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer: def test_df(mlos_core_optimizer: MlosCoreOptimizer, mock_configs: List[dict]) -> None: - """ - Test `MlosCoreOptimizer._to_df()` method on tunables that have special values. - """ + """Test `MlosCoreOptimizer._to_df()` method on tunables that have special values.""" df_config = mlos_core_optimizer._to_df(mock_configs) assert isinstance(df_config, pandas.DataFrame) assert df_config.shape == (4, 6) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index fc62b4ff1b..c97c97cf0d 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mock mlos_bench optimizer. 
-""" +"""Unit tests for mock mlos_bench optimizer.""" import os import shutil import sys @@ -22,9 +20,7 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) -> None: - """ - Test invalid max_trials initialization of mlos_core SMAC optimizer. - """ + """Test invalid max_trials initialization of mlos_core SMAC optimizer.""" test_opt_config = { 'optimizer_type': 'SMAC', 'max_trials': 10, @@ -37,9 +33,7 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> None: - """ - Test max_trials initialization of mlos_core SMAC optimizer. - """ + """Test max_trials initialization of mlos_core SMAC optimizer.""" test_opt_config = { 'optimizer_type': 'SMAC', 'max_suggestions': 123, @@ -52,8 +46,8 @@ def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> No def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGroups) -> None: - """ - Test absolute path output directory initialization of mlos_core SMAC optimizer. + """Test absolute path output directory initialization of mlos_core SMAC + optimizer. """ output_dir = path_join(_OUTPUT_DIR_PATH_BASE, _OUTPUT_DIR) test_opt_config = { @@ -72,8 +66,8 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGroups) -> None: - """ - Test relative path output directory initialization of mlos_core SMAC optimizer. + """Test relative path output directory initialization of mlos_core SMAC + optimizer. """ test_opt_config = { 'optimizer_type': 'SMAC', @@ -90,8 +84,8 @@ def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGr def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_groups: TunableGroups) -> None: - """ - Test relative path output directory initialization of mlos_core SMAC optimizer. + """Test relative path output directory initialization of mlos_core SMAC + optimizer. """ test_opt_config = { 'optimizer_type': 'SMAC', @@ -109,8 +103,8 @@ def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_gro def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunable_groups: TunableGroups) -> None: - """ - Test relative path output directory initialization of mlos_core SMAC optimizer. + """Test relative path output directory initialization of mlos_core SMAC + optimizer. """ test_opt_config = { 'optimizer_type': 'SMAC', @@ -130,9 +124,7 @@ def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunabl def test_init_mlos_core_smac_temp_output_directory(tunable_groups: TunableGroups) -> None: - """ - Test random output directory initialization of mlos_core SMAC optimizer. - """ + """Test random output directory initialization of mlos_core SMAC optimizer.""" test_opt_config = { 'optimizer_type': 'SMAC', 'output_directory': None, diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py index a94a315939..1ce5903306 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mock mlos_bench optimizer. 
-""" +"""Unit tests for mock mlos_bench optimizer.""" import pytest @@ -16,9 +14,7 @@ @pytest.fixture def mock_configurations_no_defaults() -> list: - """ - A list of 2-tuples of (tunable_values, score) to test the optimizers. - """ + """A list of 2-tuples of (tunable_values, score) to test the optimizers.""" return [ ({ "vmSize": "Standard_B4ms", @@ -43,9 +39,7 @@ def mock_configurations_no_defaults() -> list: @pytest.fixture def mock_configurations(mock_configurations_no_defaults: list) -> list: - """ - A list of 2-tuples of (tunable_values, score) to test the optimizers. - """ + """A list of 2-tuples of (tunable_values, score) to test the optimizers.""" return [ ({ "vmSize": "Standard_B4ms", @@ -57,9 +51,7 @@ def mock_configurations(mock_configurations_no_defaults: list) -> list: def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: - """ - Run several iterations of the optimizer and return the best score. - """ + """Run several iterations of the optimizer and return the best score.""" for (tunable_values, score) in mock_configurations: assert mock_opt.not_converged() tunables = mock_opt.suggest() @@ -73,34 +65,26 @@ def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> None: - """ - Make sure that mock optimizer produces consistent suggestions. - """ + """Make sure that mock optimizer produces consistent suggestions.""" score = _optimize(mock_opt, mock_configurations) assert score == pytest.approx(66.66, 0.01) def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer, mock_configurations_no_defaults: list) -> None: - """ - Make sure that mock optimizer produces consistent suggestions. - """ + """Make sure that mock optimizer produces consistent suggestions.""" score = _optimize(mock_opt_no_defaults, mock_configurations_no_defaults) assert score == pytest.approx(66.66, 0.01) def test_mock_optimizer_max(mock_opt_max: MockOptimizer, mock_configurations: list) -> None: - """ - Check the maximization mode of the mock optimizer. - """ + """Check the maximization mode of the mock optimizer.""" score = _optimize(mock_opt_max, mock_configurations) assert score == pytest.approx(99.99, 0.01) def test_mock_optimizer_register_fail(mock_opt: MockOptimizer) -> None: - """ - Check the input acceptance conditions for Optimizer.register(). - """ + """Check the input acceptance conditions for Optimizer.register().""" tunables = mock_opt.suggest() mock_opt.register(tunables, Status.SUCCEEDED, {"score": 10}) mock_opt.register(tunables, Status.FAILED) diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py index bf37040f13..dd832ce348 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mock mlos_bench optimizer. -""" +"""Unit tests for mock mlos_bench optimizer.""" from typing import Dict, List, Optional @@ -23,6 +21,7 @@ def mock_configs_str(mock_configs: List[dict]) -> List[dict]: """ Same as `mock_config` above, but with all values converted to strings. + (This can happen when we retrieve the data from storage). 
""" return [ @@ -33,9 +32,7 @@ def mock_configs_str(mock_configs: List[dict]) -> List[dict]: @pytest.fixture def mock_scores() -> List[Optional[Dict[str, TunableValue]]]: - """ - Mock benchmark results from earlier experiments. - """ + """Mock benchmark results from earlier experiments.""" return [ None, {"score": 88.88}, @@ -46,9 +43,7 @@ def mock_scores() -> List[Optional[Dict[str, TunableValue]]]: @pytest.fixture def mock_status() -> List[Status]: - """ - Mock status values for earlier experiments. - """ + """Mock status values for earlier experiments.""" return [Status.FAILED, Status.SUCCEEDED, Status.SUCCEEDED, Status.SUCCEEDED] @@ -56,9 +51,7 @@ def _test_opt_update_min(opt: Optimizer, configs: List[dict], scores: List[Optional[Dict[str, TunableValue]]], status: Optional[List[Status]] = None) -> None: - """ - Test the bulk update of the optimizer on the minimization problem. - """ + """Test the bulk update of the optimizer on the minimization problem.""" opt.bulk_register(configs, scores, status) (score, tunables) = opt.get_best_observation() assert score is not None @@ -76,9 +69,7 @@ def _test_opt_update_max(opt: Optimizer, configs: List[dict], scores: List[Optional[Dict[str, TunableValue]]], status: Optional[List[Status]] = None) -> None: - """ - Test the bulk update of the optimizer on the maximization problem. - """ + """Test the bulk update of the optimizer on the maximization problem.""" opt.bulk_register(configs, scores, status) (score, tunables) = opt.get_best_observation() assert score is not None @@ -96,9 +87,7 @@ def test_update_mock_min(mock_opt: MockOptimizer, mock_configs: List[dict], mock_scores: List[Optional[Dict[str, TunableValue]]], mock_status: List[Status]) -> None: - """ - Test the bulk update of the mock optimizer on the minimization problem. - """ + """Test the bulk update of the mock optimizer on the minimization problem.""" _test_opt_update_min(mock_opt, mock_configs, mock_scores, mock_status) # make sure the first suggestion after bulk load is *NOT* the default config: assert mock_opt.suggest().get_param_values() == { @@ -113,9 +102,7 @@ def test_update_mock_min_str(mock_opt: MockOptimizer, mock_configs_str: List[dict], mock_scores: List[Optional[Dict[str, TunableValue]]], mock_status: List[Status]) -> None: - """ - Test the bulk update of the mock optimizer with all-strings data. - """ + """Test the bulk update of the mock optimizer with all-strings data.""" _test_opt_update_min(mock_opt, mock_configs_str, mock_scores, mock_status) @@ -123,9 +110,7 @@ def test_update_mock_max(mock_opt_max: MockOptimizer, mock_configs: List[dict], mock_scores: List[Optional[Dict[str, TunableValue]]], mock_status: List[Status]) -> None: - """ - Test the bulk update of the mock optimizer on the maximization problem. - """ + """Test the bulk update of the mock optimizer on the maximization problem.""" _test_opt_update_max(mock_opt_max, mock_configs, mock_scores, mock_status) @@ -133,9 +118,7 @@ def test_update_flaml(flaml_opt: MlosCoreOptimizer, mock_configs: List[dict], mock_scores: List[Optional[Dict[str, TunableValue]]], mock_status: List[Status]) -> None: - """ - Test the bulk update of the FLAML optimizer. 
- """ + """Test the bulk update of the FLAML optimizer.""" _test_opt_update_min(flaml_opt, mock_configs, mock_scores, mock_status) @@ -143,9 +126,7 @@ def test_update_flaml_max(flaml_opt_max: MlosCoreOptimizer, mock_configs: List[dict], mock_scores: List[Optional[Dict[str, TunableValue]]], mock_status: List[Status]) -> None: - """ - Test the bulk update of the FLAML optimizer. - """ + """Test the bulk update of the FLAML optimizer.""" _test_opt_update_max(flaml_opt_max, mock_configs, mock_scores, mock_status) @@ -153,9 +134,7 @@ def test_update_smac(smac_opt: MlosCoreOptimizer, mock_configs: List[dict], mock_scores: List[Optional[Dict[str, TunableValue]]], mock_status: List[Status]) -> None: - """ - Test the bulk update of the SMAC optimizer. - """ + """Test the bulk update of the SMAC optimizer.""" _test_opt_update_min(smac_opt, mock_configs, mock_scores, mock_status) @@ -163,7 +142,5 @@ def test_update_smac_max(smac_opt_max: MlosCoreOptimizer, mock_configs: List[dict], mock_scores: List[Optional[Dict[str, TunableValue]]], mock_status: List[Status]) -> None: - """ - Test the bulk update of the SMAC optimizer. - """ + """Test the bulk update of the SMAC optimizer.""" _test_opt_update_max(smac_opt_max, mock_configs, mock_scores, mock_status) diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index 2a50f95e8c..c845f87549 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Toy optimization loop to test the optimizers on mock benchmark environment. -""" +"""Toy optimization loop to test the optimizers on mock benchmark environment.""" import logging from typing import Tuple @@ -29,9 +27,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: - """ - Toy optimization loop. - """ + """Toy optimization loop.""" assert opt.not_converged() while opt.not_converged(): @@ -71,9 +67,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: def test_mock_optimization_loop(mock_env_no_noise: MockEnv, mock_opt: MockOptimizer) -> None: - """ - Toy optimization loop with mock environment and optimizer. - """ + """Toy optimization loop with mock environment and optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, mock_opt) assert score == pytest.approx(64.9, 0.01) assert tunables.get_param_values() == { @@ -86,9 +80,7 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv, def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, mock_opt_no_defaults: MockOptimizer) -> None: - """ - Toy optimization loop with mock environment and optimizer. - """ + """Toy optimization loop with mock environment and optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, mock_opt_no_defaults) assert score == pytest.approx(60.97, 0.01) assert tunables.get_param_values() == { @@ -101,9 +93,7 @@ def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, flaml_opt: MlosCoreOptimizer) -> None: - """ - Toy optimization loop with mock environment and FLAML optimizer. 
- """ + """Toy optimization loop with mock environment and FLAML optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, flaml_opt) assert score == pytest.approx(60.15, 0.01) assert tunables.get_param_values() == { @@ -117,9 +107,7 @@ def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, # @pytest.mark.skip(reason="SMAC is not deterministic") def test_smac_optimization_loop(mock_env_no_noise: MockEnv, smac_opt: MlosCoreOptimizer) -> None: - """ - Toy optimization loop with mock environment and SMAC optimizer. - """ + """Toy optimization loop with mock environment and SMAC optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, smac_opt) expected_score = 70.33 expected_tunable_values = { diff --git a/mlos_bench/mlos_bench/tests/services/__init__.py b/mlos_bench/mlos_bench/tests/services/__init__.py index 1971c01799..fa411976e6 100644 --- a/mlos_bench/mlos_bench/tests/services/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/__init__.py @@ -4,6 +4,7 @@ # """ Tests for mlos_bench.services. + Used to make mypy happy about multiple conftest.py modules. """ diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index d6cb869f09..067715f7e4 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for configuration persistence service. -""" +"""Unit tests for configuration persistence service.""" import os import sys @@ -26,9 +24,7 @@ @pytest.fixture def config_persistence_service() -> ConfigPersistenceService: - """ - Test fixture for ConfigPersistenceService. - """ + """Test fixture for ConfigPersistenceService.""" return ConfigPersistenceService({ "config_path": [ "./non-existent-dir/test/foo/bar", # Non-existent config path @@ -41,9 +37,7 @@ def config_persistence_service() -> ConfigPersistenceService: def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None: - """ - Check that CWD is in the search path in the correct place. - """ + """Check that CWD is in the search path in the correct place.""" # pylint: disable=protected-access assert config_persistence_service._config_path is not None cwd = path_join(os.getcwd(), abs_path=True) @@ -53,9 +47,7 @@ def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersisten def test_cwd_in_default_search_path() -> None: - """ - Checks that the CWD is prepended to the search path if not explicitly present. - """ + """Checks that the CWD is prepended to the search path if not explicitly present.""" # pylint: disable=protected-access config_persistence_service = ConfigPersistenceService() assert config_persistence_service._config_path is not None @@ -66,9 +58,7 @@ def test_cwd_in_default_search_path() -> None: def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService) -> None: - """ - Check if we can actually find a file somewhere in `config_path`. 
- """ + """Check if we can actually find a file somewhere in `config_path`.""" # pylint: disable=protected-access assert config_persistence_service._config_path is not None assert ConfigPersistenceService.BUILTIN_CONFIG_PATH in config_persistence_service._config_path @@ -83,9 +73,7 @@ def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> None: - """ - Check if we can actually find a file somewhere in `config_path`. - """ + """Check if we can actually find a file somewhere in `config_path`.""" file_path = "tunable-values/tunable-values-example.jsonc" path = config_persistence_service.resolve_path(file_path) assert path.endswith(file_path) @@ -93,9 +81,7 @@ def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> N def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None: - """ - Check if non-existent file resolves without using `config_path`. - """ + """Check if non-existent file resolves without using `config_path`.""" file_path = "foo/non-existent-config.json" path = config_persistence_service.resolve_path(file_path) assert not os.path.exists(path) @@ -103,8 +89,8 @@ def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) def test_load_config(config_persistence_service: ConfigPersistenceService) -> None: - """ - Check if we can successfully load a config file located relative to `config_path`. + """Check if we can successfully load a config file located relative to + `config_path`. """ tunables_data = config_persistence_service.load_config("tunable-values/tunable-values-example.jsonc", ConfigSchema.TUNABLE_VALUES) diff --git a/mlos_bench/mlos_bench/tests/services/local/__init__.py b/mlos_bench/mlos_bench/tests/services/local/__init__.py index c6dbf7c021..01f6e04dcf 100644 --- a/mlos_bench/mlos_bench/tests/services/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/__init__.py @@ -4,6 +4,7 @@ # """ Tests for mlos_bench.services.local. + Used to make mypy happy about multiple conftest.py modules. """ diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index 572195dcc5..e1da1105b0 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for LocalExecService to run Python scripts locally. -""" +"""Unit tests for LocalExecService to run Python scripts locally.""" import json from typing import Any, Dict @@ -21,16 +19,12 @@ @pytest.fixture def local_exec_service() -> LocalExecService: - """ - Test fixture for LocalExecService. - """ + """Test fixture for LocalExecService.""" return LocalExecService(parent=ConfigPersistenceService()) def test_run_python_script(local_exec_service: LocalExecService) -> None: - """ - Run a Python script using a local_exec service. 
- """ + """Run a Python script using a local_exec service.""" input_file = "./input-params.json" meta_file = "./input-params-meta.json" output_file = "./config-kernel.sh" diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index bd5b3b7d7f..572c332282 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for the service to run the scripts locally. -""" +"""Unit tests for the service to run the scripts locally.""" import sys import tempfile @@ -21,9 +19,7 @@ def test_split_cmdline() -> None: - """ - Test splitting a commandline into subcommands. - """ + """Test splitting a commandline into subcommands.""" cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" assert list(split_cmdline(cmdline)) == [ ['.', 'env.sh'], @@ -48,9 +44,7 @@ def test_split_cmdline() -> None: @pytest.fixture def local_exec_service() -> LocalExecService: - """ - Test fixture for LocalExecService. - """ + """Test fixture for LocalExecService.""" config = { "abort_on_error": True, } @@ -58,9 +52,7 @@ def local_exec_service() -> LocalExecService: def test_resolve_script(local_exec_service: LocalExecService) -> None: - """ - Test local script resolution logic with complex subcommand names. - """ + """Test local script resolution logic with complex subcommand names.""" script = "os/linux/runtime/scripts/local/generate_kernel_config_script.py" script_abspath = local_exec_service.config_loader_service.resolve_path(script) orig_cmdline = f". env.sh && {script} --input foo" @@ -74,9 +66,7 @@ def test_resolve_script(local_exec_service: LocalExecService) -> None: def test_run_script(local_exec_service: LocalExecService) -> None: - """ - Run a script locally and check the results. - """ + """Run a script locally and check the results.""" # `echo` should work on all platforms (return_code, stdout, stderr) = local_exec_service.local_exec(["echo hello"]) assert return_code == 0 @@ -85,9 +75,7 @@ def test_run_script(local_exec_service: LocalExecService) -> None: def test_run_script_multiline(local_exec_service: LocalExecService) -> None: - """ - Run a multiline script locally and check the results. - """ + """Run a multiline script locally and check the results.""" # `echo` should work on all platforms (return_code, stdout, stderr) = local_exec_service.local_exec([ "echo hello", @@ -99,9 +87,7 @@ def test_run_script_multiline(local_exec_service: LocalExecService) -> None: def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: - """ - Run a multiline script locally and pass the environment variables to it. - """ + """Run a multiline script locally and pass the environment variables to it.""" # `echo` should work on all platforms (return_code, stdout, stderr) = local_exec_service.local_exec([ r"echo $var", # Unix shell @@ -116,9 +102,7 @@ def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: - """ - Run a script locally and read the resulting CSV file. 
- """ + """Run a script locally and read the resulting CSV file.""" with local_exec_service.temp_dir_context() as temp_dir: (return_code, stdout, stderr) = local_exec_service.local_exec([ @@ -143,9 +127,7 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None: - """ - Write data a temp location and run a script that updates it there. - """ + """Write data a temp location and run a script that updates it there.""" with local_exec_service.temp_dir_context() as temp_dir: input_file = "input.txt" @@ -166,18 +148,14 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None def test_run_script_fail(local_exec_service: LocalExecService) -> None: - """ - Try to run a non-existent command. - """ + """Try to run a non-existent command.""" (return_code, stdout, _stderr) = local_exec_service.local_exec(["foo_bar_baz hello"]) assert return_code != 0 assert stdout.strip() == "" def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> None: - """ - Try to run a series of commands, one of which fails, and abort early. - """ + """Try to run a series of commands, one of which fails, and abort early.""" (return_code, stdout, _stderr) = local_exec_service.local_exec([ "echo hello", "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", @@ -188,9 +166,7 @@ def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> N def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> None: - """ - Try to run a series of commands, one of which fails, but let it pass. - """ + """Try to run a series of commands, one of which fails, but let it pass.""" local_exec_service.abort_on_error = False (return_code, stdout, _stderr) = local_exec_service.local_exec([ "echo hello", @@ -205,9 +181,7 @@ def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> No def test_temp_dir_path_expansion() -> None: - """ - Test that we can control the temp_dir path using globals expansion. - """ + """Test that we can control the temp_dir path using globals expansion.""" # Create a temp dir for the test. # Normally this would be a real path set on the CLI or in a global config, # but for test purposes we still want it to be dynamic and cleaned up after diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py index eede9383bc..7e8035e6a0 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Mock local services for testing purposes. -""" +"""Mock local services for testing purposes.""" from .mock_local_exec_service import MockLocalExecService diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py index db8f0134c4..9582cc62c8 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for mocking local exec. 
-""" +"""A collection Service functions for mocking local exec.""" import logging from typing import ( @@ -31,9 +29,7 @@ class MockLocalExecService(TempDirContextService, SupportsLocalExec): - """ - Mock methods for LocalExecService testing. - """ + """Mock methods for LocalExecService testing.""" def __init__(self, config: Optional[Dict[str, Any]] = None, global_config: Optional[Dict[str, Any]] = None, diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py index 835738015b..e1fe7cbc5a 100644 --- a/mlos_bench/mlos_bench/tests/services/mock_service.py +++ b/mlos_bench/mlos_bench/tests/services/mock_service.py @@ -15,13 +15,13 @@ @runtime_checkable class SupportsSomeMethod(Protocol): - """Protocol for some_method""" + """Protocol for some_method.""" def some_method(self) -> str: - """some_method""" + """some_method.""" def some_other_method(self) -> str: - """some_other_method""" + """some_other_method.""" class MockServiceBase(Service, SupportsSomeMethod): @@ -43,11 +43,11 @@ def __init__( ])) def some_method(self) -> str: - """some_method""" + """some_method.""" return f"{self}: base.some_method" def some_other_method(self) -> str: - """some_other_method""" + """some_other_method.""" return f"{self}: base.some_other_method" @@ -57,5 +57,5 @@ class MockServiceChild(MockServiceBase, SupportsSomeMethod): # Intentionally includes no constructor. def some_method(self) -> str: - """some_method""" + """some_method.""" return f"{self}: child.some_method" diff --git a/mlos_bench/mlos_bench/tests/services/remote/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/__init__.py index e8a87ab684..137ea2e888 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/__init__.py @@ -4,6 +4,7 @@ # """ Tests for mlos_bench.services.remote. + Used to make mypy happy about multiple conftest.py modules. """ diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py index 9bf6e49541..d45db2383e 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests helpers for mlos_bench.services.remote.azure. -""" +"""Tests helpers for mlos_bench.services.remote.azure.""" import json from io import BytesIO @@ -12,9 +10,7 @@ def make_httplib_json_response(status: int, json_data: dict) -> urllib3.HTTPResponse: - """ - Prepare a json response object for use with urllib3 - """ + """Prepare a json response object for use with urllib3.""" data = json.dumps(json_data).encode("utf-8") response = urllib3.HTTPResponse( status=status, diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index c6475e6936..6d54389264 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
# -""" -Tests for mlos_bench.services.remote.azure.azure_fileshare -""" +"""Tests for mlos_bench.services.remote.azure.azure_fileshare.""" import os from unittest.mock import MagicMock, Mock, call, patch @@ -142,7 +140,7 @@ def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshar class MyDirEntry: # pylint: disable=too-few-public-methods - """Dummy class for os.DirEntry""" + """Dummy class for os.DirEntry.""" def __init__(self, name: str, is_a_dir: bool): self.name = name self.is_a_dir = is_a_dir diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py index d6d55d3975..67fc9d56fb 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench.services.remote.azure.azure_network_services -""" +"""Tests for mlos_bench.services.remote.azure.azure_network_services.""" from unittest.mock import MagicMock, patch @@ -28,9 +26,7 @@ def test_wait_network_deployment_retry(mock_getconn: MagicMock, total_retries: int, operation_status: Status, azure_network_service: AzureNetworkService) -> None: - """ - Test retries of the network deployment operation. - """ + """Test retries of the network deployment operation.""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ @@ -73,9 +69,7 @@ def test_network_operation_status(mock_requests: MagicMock, accepts_params: bool, http_status_code: int, operation_status: Status) -> None: - """ - Test network operation status. - """ + """Test network operation status.""" mock_response = MagicMock() mock_response.status_code = http_status_code mock_requests.post.return_value = mock_response @@ -90,9 +84,7 @@ def test_network_operation_status(mock_requests: MagicMock, @pytest.fixture def test_azure_network_service_no_deployment_template(azure_auth_service: AzureAuthService) -> None: - """ - Tests creating a network services without a deployment template (should fail). - """ + """Tests creating a network services without a deployment template (should fail).""" with pytest.raises(ValueError): _ = AzureNetworkService(config={ "deploymentTemplatePath": None, diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 1d84d73cab..6b1235f3f7 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench.services.remote.azure.azure_vm_services -""" +"""Tests for mlos_bench.services.remote.azure.azure_vm_services.""" from copy import deepcopy from unittest.mock import MagicMock, patch @@ -29,9 +27,7 @@ def test_wait_host_deployment_retry(mock_getconn: MagicMock, total_retries: int, operation_status: Status, azure_vm_service: AzureVMService) -> None: - """ - Test retries of the host deployment operation. 
- """ + """Test retries of the host deployment operation.""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ @@ -55,9 +51,7 @@ def test_wait_host_deployment_retry(mock_getconn: MagicMock, def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None: - """ - Test expanding template params recursively. - """ + """Test expanding template params recursively.""" config = { "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", "subscription": "TEST_SUB1", @@ -80,9 +74,7 @@ def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAut def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None: - """ - Test loading custom data from a file. - """ + """Test loading custom data from a file.""" config = { "customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml", "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", @@ -129,9 +121,7 @@ def test_vm_operation_status(mock_requests: MagicMock, accepts_params: bool, http_status_code: int, operation_status: Status) -> None: - """ - Test VM operation status. - """ + """Test VM operation status.""" mock_response = MagicMock() mock_response.status_code = http_status_code mock_requests.post.return_value = mock_response @@ -151,9 +141,7 @@ def test_vm_operation_status(mock_requests: MagicMock, def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, operation_name: str, accepts_params: bool) -> None: - """ - Test VM operation status for an incomplete service config. - """ + """Test VM operation status for an incomplete service config.""" operation = getattr(azure_vm_service_remote_exec_only, operation_name) with pytest.raises(ValueError): (_, _) = operation({"vmName": "test-vm"}) if accepts_params else operation() @@ -163,9 +151,7 @@ def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService) -> None: - """ - Test waiting for the completion of the remote VM operation. - """ + """Test waiting for the completion of the remote VM operation.""" # Mock response header async_url = "DUMMY_ASYNC_URL" retry_after = 12345 @@ -191,9 +177,7 @@ def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") def test_wait_vm_operation_timeout(mock_session: MagicMock, azure_vm_service: AzureVMService) -> None: - """ - Test the time out of the remote VM operation. - """ + """Test the time out of the remote VM operation.""" # Mock response header params = { "asyncResultsUrl": "DUMMY_ASYNC_URL", @@ -222,9 +206,7 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock, total_retries: int, operation_status: Status, azure_vm_service: AzureVMService) -> None: - """ - Test the retries of the remote VM operation. 
- """ + """Test the retries of the remote VM operation.""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ @@ -255,9 +237,7 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock, @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService, http_status_code: int, operation_status: Status) -> None: - """ - Test waiting for completion of the remote execution on Azure. - """ + """Test waiting for completion of the remote execution on Azure.""" script = ["command_1", "command_2"] mock_response = MagicMock() @@ -275,9 +255,7 @@ def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_ex @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") def test_remote_exec_headers_output(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService) -> None: - """ - Check if HTTP headers from the remote execution on Azure are correct. - """ + """Check if HTTP headers from the remote execution on Azure are correct.""" async_url_key = "asyncResultsUrl" async_url_value = "DUMMY_ASYNC_URL" script = ["command_1", "command_2"] @@ -330,9 +308,7 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, ]) def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status, wait_output: dict, results_output: dict) -> None: - """ - Test getting the results of the remote execution on Azure. - """ + """Test getting the results of the remote execution on Azure.""" params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"} mock_wait_host_operation = MagicMock() diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 2794bb01cf..14a18f94ef 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Configuration test fixtures for azure_vm_services in mlos_bench. -""" +"""Configuration test fixtures for azure_vm_services in mlos_bench.""" from unittest.mock import patch @@ -23,18 +21,14 @@ @pytest.fixture def config_persistence_service() -> ConfigPersistenceService: - """ - Test fixture for ConfigPersistenceService. - """ + """Test fixture for ConfigPersistenceService.""" return ConfigPersistenceService() @pytest.fixture def azure_auth_service(config_persistence_service: ConfigPersistenceService, monkeypatch: pytest.MonkeyPatch) -> AzureAuthService: - """ - Creates a dummy AzureAuthService for tests that require it. - """ + """Creates a dummy AzureAuthService for tests that require it.""" auth = AzureAuthService(config={}, global_config={}, parent=config_persistence_service) monkeypatch.setattr(auth, "get_access_token", lambda: "TEST_TOKEN") return auth @@ -42,9 +36,7 @@ def azure_auth_service(config_persistence_service: ConfigPersistenceService, @pytest.fixture def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkService: - """ - Creates a dummy Azure VM service for tests that require it. 
- """ + """Creates a dummy Azure VM service for tests that require it.""" return AzureNetworkService(config={ "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", "subscription": "TEST_SUB", @@ -62,9 +54,7 @@ def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkS @pytest.fixture def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService: - """ - Creates a dummy Azure VM service for tests that require it. - """ + """Creates a dummy Azure VM service for tests that require it.""" return AzureVMService(config={ "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", "subscription": "TEST_SUB", @@ -82,9 +72,7 @@ def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService: @pytest.fixture def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> AzureVMService: - """ - Creates a dummy Azure VM service with no deployment template. - """ + """Creates a dummy Azure VM service with no deployment template.""" return AzureVMService(config={ "subscription": "TEST_SUB", "resourceGroup": "TEST_RG", @@ -97,9 +85,7 @@ def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> A @pytest.fixture def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService: - """ - Creates a dummy AzureFileShareService for tests that require it. - """ + """Creates a dummy AzureFileShareService for tests that require it.""" with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"): return AzureFileShareService(config={ "storageAccountName": "TEST_ACCOUNT_NAME", diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py index d86cbdf2a3..a12bde8a23 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Mock remote services for testing purposes. -""" +"""Mock remote services for testing purposes.""" from typing import Any, Tuple diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index b9474f0709..9f75d79eac 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for mocking authentication. -""" +"""A collection Service functions for mocking authentication.""" import logging from typing import Any, Callable, Dict, List, Optional, Union @@ -16,9 +14,7 @@ class MockAuthService(Service, SupportsAuth): - """ - A collection Service functions for mocking authentication ops. 
- """ + """A collection Service functions for mocking authentication ops.""" def __init__(self, config: Optional[Dict[str, Any]] = None, global_config: Optional[Dict[str, Any]] = None, diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index 1a026966a8..5378e12837 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for mocking file share ops. -""" +"""A collection Service functions for mocking file share ops.""" import logging from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -17,9 +15,7 @@ class MockFileShareService(FileShareService, SupportsFileShareOps): - """ - A collection Service functions for mocking file share ops. - """ + """A collection Service functions for mocking file share ops.""" def __init__(self, config: Optional[Dict[str, Any]] = None, global_config: Optional[Dict[str, Any]] = None, @@ -39,13 +35,9 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b self._download.append((remote_path, local_path)) def get_upload(self) -> List[Tuple[str, str]]: - """ - Get the list of files that were uploaded. - """ + """Get the list of files that were uploaded.""" return self._upload def get_download(self) -> List[Tuple[str, str]]: - """ - Get the list of files that were downloaded. - """ + """Get the list of files that were downloaded.""" return self._download diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py index e6169d9f93..03a02ba14e 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for mocking managing (Virtual) Networks. -""" +"""A collection Service functions for mocking managing (Virtual) Networks.""" from typing import Any, Callable, Dict, List, Optional, Union @@ -16,9 +14,7 @@ class MockNetworkService(Service, SupportsNetworkProvisioning): - """ - Mock Network service for testing. - """ + """Mock Network service for testing.""" def __init__(self, config: Optional[Dict[str, Any]] = None, global_config: Optional[Dict[str, Any]] = None, diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py index ee99251c64..f1e29e5cd4 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for mocking remote script execution. -""" +"""A collection Service functions for mocking remote script execution.""" from typing import Any, Callable, Dict, List, Optional, Union @@ -14,9 +12,7 @@ class MockRemoteExecService(Service, SupportsRemoteExec): - """ - Mock remote script execution service. 
- """ + """Mock remote script execution service.""" def __init__(self, config: Optional[Dict[str, Any]] = None, global_config: Optional[Dict[str, Any]] = None, diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index a44edaf080..1fe659a23f 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -A collection Service functions for mocking managing VMs. -""" +"""A collection Service functions for mocking managing VMs.""" from typing import Any, Callable, Dict, List, Optional, Union @@ -16,9 +14,7 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps): - """ - Mock VM service for testing. - """ + """Mock VM service for testing.""" def __init__(self, config: Optional[Dict[str, Any]] = None, global_config: Optional[Dict[str, Any]] = None, diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index e0060d8047..9d5e0ef153 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Common data classes for the SSH service tests. -""" +"""Common data classes for the SSH service tests.""" from dataclasses import dataclass from subprocess import run @@ -24,9 +22,7 @@ @dataclass class SshTestServerInfo: - """ - A data class for SshTestServerInfo. - """ + """A data class for SshTestServerInfo.""" compose_project_name: str service_name: str @@ -59,6 +55,7 @@ def to_ssh_service_config(self, uncached: bool = False) -> dict: def to_connect_params(self, uncached: bool = False) -> dict: """ Convert to a connect_params dict for SshClient. + See Also: mlos_bench.services.remote.ssh.ssh_service.SshService._get_connect_params() """ return { diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py index 1bb910ed77..34006985af 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Fixtures for the SSH service tests. -""" +"""Fixtures for the SSH service tests.""" import mlos_bench.tests.services.remote.ssh.fixtures as ssh_fixtures diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 6f05fe953b..28c0367afa 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -51,14 +51,13 @@ def ssh_test_server(ssh_test_server_hostname: str, docker_compose_project_name: str, locked_docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]: """ - Fixture for getting the ssh test server services setup via docker-compose - using pytest-docker. + Fixture for getting the ssh test server services setup via docker-compose using + pytest-docker. Yields the (hostname, port, username, id_rsa_path) of the test server. 
- Once the session is over, the docker containers are torn down, and the - temporary file holding the dynamically generated private key of the test - server is deleted. + Once the session is over, the docker containers are torn down, and the temporary + file holding the dynamically generated private key of the test server is deleted. """ # Get a copy of the ssh id_rsa key from the test ssh server. with tempfile.NamedTemporaryFile() as id_rsa_file: @@ -85,6 +84,7 @@ def alt_test_server(ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices) -> SshTestServerInfo: """ Fixture for getting the second ssh test server info from the docker-compose.yml. + See additional notes in the ssh_test_server fixture above. """ # Note: The alt-server uses the same image as the ssh-server container, so @@ -105,6 +105,7 @@ def reboot_test_server(ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices) -> SshTestServerInfo: """ Fixture for getting the third ssh test server info from the docker-compose.yml. + See additional notes in the ssh_test_server fixture above. """ # Note: The reboot-server uses the same image as the ssh-server container, so diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index f2bbbe4b8a..7b1a4e0756 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench.services.remote.ssh.ssh_services -""" +"""Tests for mlos_bench.services.remote.ssh.ssh_services.""" import os import tempfile @@ -26,8 +24,8 @@ @contextmanager def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, None]: """ - Provides a context manager for a temporary file that can be closed and - still unlinked. + Provides a context manager for a temporary file that can be closed and still + unlinked. Since Windows doesn't allow us to reopen the file while it's still open we need to handle deletion ourselves separately. diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 4c8e5e0c66..ab7f3cd9e0 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench.services.remote.ssh.ssh_host_service -""" +"""Tests for mlos_bench.services.remote.ssh.ssh_host_service.""" import logging import time @@ -33,8 +31,8 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, """ Test the SshHostService remote_exec. - This checks state of the service across multiple invocations and states to - check for internal cache handling logic as well. + This checks state of the service across multiple invocations and states to check for + internal cache handling logic as well. """ # pylint: disable=protected-access with ssh_host_service: @@ -138,9 +136,7 @@ def check_ssh_service_reboot(docker_services: DockerServices, reboot_test_server: SshTestServerInfo, ssh_host_service: SshHostService, graceful: bool) -> None: - """ - Check the SshHostService reboot operation. 
- """ + """Check the SshHostService reboot operation.""" # Note: rebooting changes the port number unfortunately, but makes it # easier to check for success. # Also, it may cause issues with other parallel unit tests, so we run it as @@ -211,9 +207,7 @@ def check_ssh_service_reboot(docker_services: DockerServices, def test_ssh_service_reboot(locked_docker_services: DockerServices, reboot_test_server: SshTestServerInfo, ssh_host_service: SshHostService) -> None: - """ - Test the SshHostService reboot operation. - """ + """Test the SshHostService reboot operation.""" # Grouped together to avoid parallel runner interactions. check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=True) check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=False) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index 7bee929fea..1eabd7ea37 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench.services.remote.ssh.SshService base class. -""" +"""Tests for mlos_bench.services.remote.ssh.SshService base class.""" import asyncio import time @@ -71,6 +69,7 @@ def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, def test_ssh_service_context_handler() -> None: """ Test the SSH service context manager handling. + See Also: test_event_loop_context """ # pylint: disable=protected-access diff --git a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py index 463879634f..088223279b 100644 --- a/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py +++ b/mlos_bench/mlos_bench/tests/services/test_service_method_registering.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for Service method registering. -""" +"""Unit tests for Service method registering.""" import pytest @@ -17,9 +15,7 @@ def test_service_method_register_without_constructor() -> None: - """ - Test registering a method without a constructor. - """ + """Test registering a method without a constructor.""" # pylint: disable=protected-access some_base_service = MockServiceBase() some_child_service = MockServiceChild() diff --git a/mlos_bench/mlos_bench/tests/storage/__init__.py b/mlos_bench/mlos_bench/tests/storage/__init__.py index ca5a3b33dd..c3b294cae1 100644 --- a/mlos_bench/mlos_bench/tests/storage/__init__.py +++ b/mlos_bench/mlos_bench/tests/storage/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench storage. -""" +"""Tests for mlos_bench storage.""" CONFIG_COUNT = 10 CONFIG_TRIAL_REPEAT_COUNT = 3 diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index 2c16df65c4..879be9497a 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Export test fixtures for mlos_bench storage. 
-""" +"""Export test fixtures for mlos_bench storage.""" import mlos_bench.tests.storage.sql.fixtures as sql_storage_fixtures diff --git a/mlos_bench/mlos_bench/tests/storage/exp_context_test.py b/mlos_bench/mlos_bench/tests/storage/exp_context_test.py index e2b1d7c26b..f0bfa1d127 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_context_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_context_test.py @@ -2,16 +2,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for the storage subsystem. -""" +"""Unit tests for the storage subsystem.""" from mlos_bench.storage.base_storage import Storage def test_exp_context(exp_storage: Storage.Experiment) -> None: - """ - Try to retrieve old experimental data from the empty storage. - """ + """Try to retrieve old experimental data from the empty storage.""" # pylint: disable=protected-access assert exp_storage._in_context diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index 8159043be1..941683333e 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for loading the experiment metadata. -""" +"""Unit tests for loading the experiment metadata.""" from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.storage.base_storage import Storage @@ -13,9 +11,7 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) -> None: - """ - Try to retrieve old experimental data from the empty storage. - """ + """Try to retrieve old experimental data from the empty storage.""" exp = storage.experiments[exp_storage.experiment_id] assert exp.experiment_id == exp_storage.experiment_id assert exp.description == exp_storage.description @@ -23,7 +19,7 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None: - """Tests the root_env_config property of ExperimentData""" + """Tests the root_env_config property of ExperimentData.""" # pylint: disable=protected-access assert exp_data.root_env_config == (exp_storage._root_env_config, exp_storage._git_repo, exp_storage._git_commit) @@ -31,9 +27,7 @@ def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: Exp def test_exp_trial_data_objectives(storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: - """ - Start a new trial and check the storage for the trial data. 
- """ + """Start a new trial and check the storage for the trial data.""" trial_opt_new = exp_storage.new_trial(tunable_groups, config={ "opt_target": "some-other-target", @@ -67,7 +61,7 @@ def test_exp_trial_data_objectives(storage: Storage, def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: - """Tests the results_df property of ExperimentData""" + """Tests the results_df property of ExperimentData.""" results_df = exp_data.results_df expected_trials_count = CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT assert len(results_df) == expected_trials_count @@ -81,7 +75,7 @@ def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGr def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None: """ - Tests the tunable_config_trial_group_id property of ExperimentData.results_df + Tests the tunable_config_trial_group_id property of ExperimentData.results_df. See Also: test_exp_trial_data_tunable_config_trial_group_id() """ @@ -109,7 +103,7 @@ def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: Experime def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None: """ - Tests the tunable_config_trial_groups property of ExperimentData + Tests the tunable_config_trial_groups property of ExperimentData. This tests bulk loading of the tunable_config_trial_groups. """ @@ -126,7 +120,7 @@ def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None: def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: - """Tests the tunable_configs property of ExperimentData""" + """Tests the tunable_configs property of ExperimentData.""" # Should be keyed by config_id. assert list(exp_data.tunable_configs.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. @@ -136,5 +130,5 @@ def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: def test_exp_data_default_config_id(exp_data: ExperimentData) -> None: - """Tests the default_tunable_config_id property of ExperimentData""" + """Tests the default_tunable_config_id property of ExperimentData.""" assert exp_data.default_tunable_config_id == 1 diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index d0a5edc694..d69a580b9e 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for the storage subsystem. -""" +"""Unit tests for the storage subsystem.""" from datetime import datetime, tzinfo from typing import Optional @@ -18,9 +16,7 @@ def test_exp_load_empty(exp_storage: Storage.Experiment) -> None: - """ - Try to retrieve old experimental data from the empty storage. - """ + """Try to retrieve old experimental data from the empty storage.""" (trial_ids, configs, scores, status) = exp_storage.load() assert not trial_ids assert not configs @@ -29,9 +25,7 @@ def test_exp_load_empty(exp_storage: Storage.Experiment) -> None: def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: - """ - Try to retrieve pending experiments from the empty storage. 
- """ + """Try to retrieve pending experiments from the empty storage.""" trials = list(exp_storage.pending_trials(datetime.now(UTC), running=True)) assert not trials @@ -40,9 +34,7 @@ def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Start a trial and check that it is pending. - """ + """Start a trial and check that it is pending.""" trial = exp_storage.new_trial(tunable_groups) (pending,) = list(exp_storage.pending_trials(datetime.now(zone_info), running=True)) assert pending.trial_id == trial.trial_id @@ -53,9 +45,7 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, def test_exp_trial_pending_many(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Start THREE trials and check that both are pending. - """ + """Start THREE trials and check that both are pending.""" config1 = tunable_groups.copy().assign({'idle': 'mwait'}) config2 = tunable_groups.copy().assign({'idle': 'noidle'}) trial_ids = { @@ -75,9 +65,7 @@ def test_exp_trial_pending_many(exp_storage: Storage.Experiment, def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Start a trial, fail it, and and check that it is NOT pending. - """ + """Start a trial, fail it, and and check that it is NOT pending.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.FAILED, datetime.now(zone_info)) trials = list(exp_storage.pending_trials(datetime.now(zone_info), running=True)) @@ -88,9 +76,7 @@ def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, def test_exp_trial_success(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Start a trial, finish it successfully, and and check that it is NOT pending. - """ + """Start a trial, finish it successfully, and and check that it is NOT pending.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9}) trials = list(exp_storage.pending_trials(datetime.now(zone_info), running=True)) @@ -101,9 +87,7 @@ def test_exp_trial_success(exp_storage: Storage.Experiment, def test_exp_trial_update_categ(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Update the trial with multiple metrics, some of which are categorical. - """ + """Update the trial with multiple metrics, some of which are categorical.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"}) assert exp_storage.load() == ( @@ -123,9 +107,7 @@ def test_exp_trial_update_categ(exp_storage: Storage.Experiment, def test_exp_trial_update_twice(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: - """ - Update the trial status twice and receive an error. - """ + """Update the trial status twice and receive an error.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.FAILED, datetime.now(zone_info)) with pytest.raises(RuntimeError): @@ -138,6 +120,7 @@ def test_exp_trial_pending_3(exp_storage: Storage.Experiment, zone_info: Optional[tzinfo]) -> None: """ Start THREE trials, let one succeed, another one fail and keep one not updated. 
+ Check that one is still pending another one can be loaded into the optimizer. """ score = 99.9 diff --git a/mlos_bench/mlos_bench/tests/storage/sql/__init__.py b/mlos_bench/mlos_bench/tests/storage/sql/__init__.py index 61b8ec8df4..d17a448b5e 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/__init__.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/__init__.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_bench sql storage. -""" +"""Tests for mlos_bench sql storage.""" diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 7e346a5ccc..839404ff0b 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Test fixtures for mlos_bench storage. -""" +"""Test fixtures for mlos_bench storage.""" from datetime import datetime from random import random @@ -27,9 +25,7 @@ @pytest.fixture def storage() -> SqlStorage: - """ - Test fixture for in-memory SQLite3 storage. - """ + """Test fixture for in-memory SQLite3 storage.""" return SqlStorage( service=None, config={ @@ -47,6 +43,7 @@ def exp_storage( ) -> Generator[SqlStorage.Experiment, None, None]: """ Test fixture for Experiment using in-memory SQLite3 storage. + Note: It has already entered the context upon return. """ with storage.experiment( @@ -68,6 +65,7 @@ def exp_no_tunables_storage( ) -> Generator[SqlStorage.Experiment, None, None]: """ Test fixture for Experiment using in-memory SQLite3 storage. + Note: It has already entered the context upon return. """ empty_config: dict = {} @@ -90,7 +88,9 @@ def mixed_numerics_exp_storage( mixed_numerics_tunable_groups: TunableGroups, ) -> Generator[SqlStorage.Experiment, None, None]: """ - Test fixture for an Experiment with mixed numerics tunables using in-memory SQLite3 storage. + Test fixture for an Experiment with mixed numerics tunables using in-memory SQLite3 + storage. + Note: It has already entered the context upon return. """ with storage.experiment( @@ -107,9 +107,7 @@ def mixed_numerics_exp_storage( def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> SqlStorage.Experiment: - """ - Generates data by doing a simulated run of the given experiment. - """ + """Generates data by doing a simulated run of the given experiment.""" # Add some trials to that experiment. # Note: we're just fabricating some made up function for the ML libraries to try and learn. base_score = 10.0 @@ -160,49 +158,37 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S @pytest.fixture def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: - """ - Test fixture for Experiment using in-memory SQLite3 storage. - """ + """Test fixture for Experiment using in-memory SQLite3 storage.""" return _dummy_run_exp(exp_storage, tunable_name="kernel_sched_latency_ns") @pytest.fixture def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: - """ - Test fixture for Experiment using in-memory SQLite3 storage. 
- """ + """Test fixture for Experiment using in-memory SQLite3 storage.""" assert not exp_no_tunables_storage.tunables return _dummy_run_exp(exp_no_tunables_storage, tunable_name=None) @pytest.fixture def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: - """ - Test fixture for Experiment using in-memory SQLite3 storage. - """ + """Test fixture for Experiment using in-memory SQLite3 storage.""" tunable = next(iter(mixed_numerics_exp_storage.tunables))[0] return _dummy_run_exp(mixed_numerics_exp_storage, tunable_name=tunable.name) @pytest.fixture def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: - """ - Test fixture for ExperimentData. - """ + """Test fixture for ExperimentData.""" return storage.experiments[exp_storage_with_trials.experiment_id] @pytest.fixture def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: - """ - Test fixture for ExperimentData with no tunable configs. - """ + """Test fixture for ExperimentData with no tunable configs.""" return storage.experiments[exp_no_tunables_storage_with_trials.experiment_id] @pytest.fixture def mixed_numerics_exp_data(storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: - """ - Test fixture for ExperimentData with mixed numerical tunable types. - """ + """Test fixture for ExperimentData with mixed numerical tunable types.""" return storage.experiments[mixed_numerics_exp_storage_with_trials.experiment_id] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py index ba965ed3c6..851993f4a2 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for saving and retrieving additional parameters of pending trials. -""" +"""Unit tests for saving and retrieving additional parameters of pending trials.""" from datetime import datetime from pytz import UTC @@ -15,9 +13,7 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: - """ - Schedule a trial and check that it is pending and has the right configuration. - """ + """Schedule a trial and check that it is pending and has the right configuration.""" config = {"location": "westus2", "num_repeats": 100} trial = exp_storage.new_trial(tunable_groups, config=config) (pending,) = list(exp_storage.pending_trials(datetime.now(UTC), running=True)) @@ -33,9 +29,8 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: - """ - Start multiple trials with two different configs and check that - we store only two config objects in the DB. + """Start multiple trials with two different configs and check that we store only two + config objects in the DB. """ config1 = tunable_groups.copy().assign({'idle': 'mwait'}) trials1 = [ @@ -67,9 +62,7 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment, def test_exp_trial_no_config(exp_no_tunables_storage: Storage.Experiment) -> None: - """ - Schedule a trial that has an empty tunable groups config. 
- """ + """Schedule a trial that has an empty tunable groups config.""" empty_config: dict = {} tunable_groups = TunableGroups(config=empty_config) trial = exp_no_tunables_storage.new_trial(tunable_groups, config=empty_config) diff --git a/mlos_bench/mlos_bench/tests/storage/trial_data_test.py b/mlos_bench/mlos_bench/tests/storage/trial_data_test.py index c3703c9a13..9fe59b426b 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_data_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for loading the trial metadata. -""" +"""Unit tests for loading the trial metadata.""" from datetime import datetime @@ -15,9 +13,7 @@ def test_exp_trial_data(exp_data: ExperimentData) -> None: - """ - Check expected return values for TrialData. - """ + """Check expected return values for TrialData.""" trial_id = 1 expected_config_id = 1 trial = exp_data.trials[trial_id] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index 04f4f18ae3..628051a373 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for scheduling trials for some future time. -""" +"""Unit tests for scheduling trials for some future time.""" from datetime import datetime, timedelta from typing import Iterator, Set @@ -16,16 +14,14 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]: - """ - Extract trial IDs from a list of trials. - """ + """Extract trial IDs from a list of trials.""" return set(t.trial_id for t in trials) def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: - """ - Schedule several trials for future execution and retrieve them later at certain timestamps. + """Schedule several trials for future execution and retrieve them later at certain + timestamps. """ timestamp = datetime.now(UTC) timedelta_1min = timedelta(minutes=1) diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index 855c6cd861..cffaaac4c6 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for saving and restoring the telemetry data. -""" +"""Unit tests for saving and restoring the telemetry data.""" from datetime import datetime, timedelta, tzinfo from typing import Any, List, Optional, Tuple @@ -43,9 +41,7 @@ def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, st def _telemetry_str(data: List[Tuple[datetime, str, Any]] ) -> List[Tuple[datetime, str, Optional[str]]]: - """ - Convert telemetry values to strings. - """ + """Convert telemetry values to strings.""" # All retrieved timestamps should have been converted to UTC. return [(ts.astimezone(UTC), key, nullable(str, val)) for (ts, key, val) in data] @@ -55,9 +51,7 @@ def test_update_telemetry(storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups, origin_zone_info: Optional[tzinfo]) -> None: - """ - Make sure update_telemetry() and load_telemetry() methods work. 
- """ + """Make sure update_telemetry() and load_telemetry() methods work.""" telemetry_data = zoned_telemetry_data(origin_zone_info) trial = exp_storage.new_trial(tunable_groups) assert exp_storage.load_telemetry(trial.trial_id) == [] @@ -76,9 +70,7 @@ def test_update_telemetry(storage: Storage, def test_update_telemetry_twice(exp_storage: Storage.Experiment, tunable_groups: TunableGroups, origin_zone_info: Optional[tzinfo]) -> None: - """ - Make sure update_telemetry() call is idempotent. - """ + """Make sure update_telemetry() call is idempotent.""" telemetry_data = zoned_telemetry_data(origin_zone_info) trial = exp_storage.new_trial(tunable_groups) timestamp = datetime.now(origin_zone_info) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index 3b57222822..ea13f63ea5 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for loading the TunableConfigData. -""" +"""Unit tests for loading the TunableConfigData.""" from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.tunables.tunable_groups import TunableGroups @@ -12,9 +10,7 @@ def test_trial_data_tunable_config_data(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: - """ - Check expected return values for TunableConfigData. - """ + """Check expected return values for TunableConfigData.""" trial_id = 1 expected_config_id = 1 trial = exp_data.trials[trial_id] @@ -26,9 +22,7 @@ def test_trial_data_tunable_config_data(exp_data: ExperimentData, def test_trial_metadata(exp_data: ExperimentData) -> None: - """ - Check expected return values for TunableConfigData metadata. - """ + """Check expected return values for TunableConfigData metadata.""" assert exp_data.objectives == {'score': 'min'} for (trial_id, trial) in exp_data.trials.items(): assert trial.metadata_dict == { @@ -39,9 +33,7 @@ def test_trial_metadata(exp_data: ExperimentData) -> None: def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData) -> None: - """ - Check expected return values for TunableConfigData. - """ + """Check expected return values for TunableConfigData.""" empty_config: dict = {} for _trial_id, trial in exp_no_tunables_data.trials.items(): assert trial.tunable_config.config_dict == empty_config @@ -50,8 +42,7 @@ def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData def test_mixed_numerics_exp_trial_data( mixed_numerics_exp_data: ExperimentData, mixed_numerics_tunable_groups: TunableGroups) -> None: - """ - Tests that data type conversions are retained when loading experiment data with + """Tests that data type conversions are retained when loading experiment data with mixed numeric tunable types. """ trial = next(iter(mixed_numerics_exp_data.trials.values())) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index d08b26e92d..0646129e42 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for loading the TunableConfigTrialGroupData. 
-""" +"""Unit tests for loading the TunableConfigTrialGroupData.""" from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.tests.storage import CONFIG_TRIAL_REPEAT_COUNT diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py index fa947610da..cd7edcd005 100644 --- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py +++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests various other test scenarios with alternative default (un-named) TZ info. -""" +"""Tests various other test scenarios with alternative default (un-named) TZ info.""" import os import sys @@ -28,9 +26,7 @@ @pytest.mark.parametrize(("tz_name"), ZONE_NAMES) @pytest.mark.parametrize(("test_file"), TZ_TEST_FILES) def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: - """ - Run the tests under alternative default (un-named) TZ info. - """ + """Run the tests under alternative default (un-named) TZ info.""" env = os.environ.copy() if tz_name is None: env.pop("TZ", None) diff --git a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py index 822547b1da..64ab724be8 100644 --- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py +++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Common fixtures for mock TunableGroups. -""" +"""Common fixtures for mock TunableGroups.""" from typing import Any, Dict @@ -72,9 +70,7 @@ @pytest.fixture def tunable_groups_config() -> Dict[str, Any]: - """ - Fixture to get the JSON string for the tunable groups. - """ + """Fixture to get the JSON string for the tunable groups.""" conf = json.loads(TUNABLE_GROUPS_JSON) assert isinstance(conf, dict) ConfigSchema.TUNABLE_PARAMS.validate(conf) diff --git a/mlos_bench/mlos_bench/tests/tunables/__init__.py b/mlos_bench/mlos_bench/tests/tunables/__init__.py index 83c046e575..69ef4a9204 100644 --- a/mlos_bench/mlos_bench/tests/tunables/__init__.py +++ b/mlos_bench/mlos_bench/tests/tunables/__init__.py @@ -4,5 +4,6 @@ # """ Tests for mlos_bench.tunables. + Used to make mypy happy about multiple conftest.py modules. """ diff --git a/mlos_bench/mlos_bench/tests/tunables/conftest.py b/mlos_bench/mlos_bench/tests/tunables/conftest.py index 95de20d9b8..f5b1629c9f 100644 --- a/mlos_bench/mlos_bench/tests/tunables/conftest.py +++ b/mlos_bench/mlos_bench/tests/tunables/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Test fixtures for individual Tunable objects. -""" +"""Test fixtures for individual Tunable objects.""" import pytest diff --git a/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py b/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py index 0b3e124779..50e9061222 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_empty_tunable_group.py @@ -2,23 +2,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for empty tunable groups. -""" +"""Unit tests for empty tunable groups.""" from mlos_bench.tunables.tunable_groups import TunableGroups def test_empty_tunable_group() -> None: - """ - Test __nonzero__ property of tunable groups. 
- """ + """Test __nonzero__ property of tunable groups.""" tunable_groups = TunableGroups(config={}) assert not tunable_groups def test_non_empty_tunable_group(tunable_groups: TunableGroups) -> None: - """ - Test __nonzero__ property of tunable groups. - """ + """Test __nonzero__ property of tunable groups.""" assert tunable_groups diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py index 0e910f3761..28f92d4769 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py @@ -2,17 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for deep copy of tunable objects and groups. -""" +"""Unit tests for deep copy of tunable objects and groups.""" from mlos_bench.tunables.tunable_groups import TunableGroups def test_tunable_categorical_types() -> None: - """ - Check if we accept tunable categoricals as ints as well as strings and - convert both to strings. + """Check if we accept tunable categoricals as ints as well as strings and convert + both to strings. """ tunable_params = { "test-group": { diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index 0181957cd0..768be65cb2 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for checking tunable size properties. -""" +"""Unit tests for checking tunable size properties.""" import numpy as np import pytest @@ -16,7 +14,7 @@ def test_tunable_int_size_props() -> None: - """Test tunable int size properties""" + """Test tunable int size properties.""" tunable = Tunable( name="test", config={ @@ -32,7 +30,7 @@ def test_tunable_int_size_props() -> None: def test_tunable_float_size_props() -> None: - """Test tunable float size properties""" + """Test tunable float size properties.""" tunable = Tunable( name="test", config={ @@ -47,7 +45,7 @@ def test_tunable_float_size_props() -> None: def test_tunable_categorical_size_props() -> None: - """Test tunable categorical size properties""" + """Test tunable categorical size properties.""" tunable = Tunable( name="test", config={ @@ -64,7 +62,7 @@ def test_tunable_categorical_size_props() -> None: def test_tunable_quantized_int_size_props() -> None: - """Test quantized tunable int size properties""" + """Test quantized tunable int size properties.""" tunable = Tunable( name="test", config={ @@ -81,7 +79,7 @@ def test_tunable_quantized_int_size_props() -> None: def test_tunable_quantized_float_size_props() -> None: - """Test quantized tunable float size properties""" + """Test quantized tunable float size properties.""" tunable = Tunable( name="test", config={ diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py index fcf0d5b9e5..8aa888ebca 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_accessors_test.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for accessing values to the individual parameters within tunable groups. 
+"""Unit tests for accessing values to the individual parameters within tunable +groups. """ import pytest @@ -13,9 +13,7 @@ def test_categorical_access_to_numerical_tunable(tunable_int: Tunable) -> None: - """ - Make sure we throw an error on accessing a numerical tunable as a categorical. - """ + """Make sure we throw an error on accessing a numerical tunable as a categorical.""" with pytest.raises(ValueError): print(tunable_int.category) with pytest.raises(AssertionError): @@ -23,9 +21,7 @@ def test_categorical_access_to_numerical_tunable(tunable_int: Tunable) -> None: def test_numerical_access_to_categorical_tunable(tunable_categorical: Tunable) -> None: - """ - Make sure we throw an error on accessing a numerical tunable as a categorical. - """ + """Make sure we throw an error on accessing a numerical tunable as a categorical.""" with pytest.raises(ValueError): print(tunable_categorical.numerical_value) with pytest.raises(AssertionError): @@ -33,15 +29,11 @@ def test_numerical_access_to_categorical_tunable(tunable_categorical: Tunable) - def test_covariant_group_repr(covariant_group: CovariantTunableGroup) -> None: - """ - Tests that the covariant group representation works as expected. - """ + """Tests that the covariant group representation works as expected.""" assert repr(covariant_group).startswith(f"{covariant_group.name}:") def test_covariant_group_tunables(covariant_group: CovariantTunableGroup) -> None: - """ - Tests that we can access the tunables in the covariant group. - """ + """Tests that we can access the tunables in the covariant group.""" for tunable in covariant_group.get_tunables(): assert isinstance(tunable, Tunable) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py index 6a91b14016..8d214c051b 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for checking tunable comparisons. -""" +"""Unit tests for checking tunable comparisons.""" import pytest @@ -14,9 +12,7 @@ def test_tunable_int_value_lt(tunable_int: Tunable) -> None: - """ - Tests that the __lt__ operator works as expected. - """ + """Tests that the __lt__ operator works as expected.""" tunable_int_2 = tunable_int.copy() tunable_int_2.numerical_value += 1 assert tunable_int.numerical_value < tunable_int_2.numerical_value @@ -24,18 +20,14 @@ def test_tunable_int_value_lt(tunable_int: Tunable) -> None: def test_tunable_int_name_lt(tunable_int: Tunable) -> None: - """ - Tests that the __lt__ operator works as expected. - """ + """Tests that the __lt__ operator works as expected.""" tunable_int_2 = tunable_int.copy() tunable_int_2._name = "aaa" # pylint: disable=protected-access assert tunable_int_2 < tunable_int def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None: - """ - Tests that the __lt__ operator works as expected. - """ + """Tests that the __lt__ operator works as expected.""" tunable_categorical_2 = tunable_categorical.copy() new_value = [ x for x in tunable_categorical.categories @@ -50,9 +42,7 @@ def test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None: def test_tunable_categorical_lt_null() -> None: - """ - Tests that the __lt__ operator works as expected. 
- """ + """Tests that the __lt__ operator works as expected.""" tunable_cat = Tunable( name="same-name", config={ @@ -73,9 +63,7 @@ def test_tunable_categorical_lt_null() -> None: def test_tunable_lt_same_name_different_type() -> None: - """ - Tests that the __lt__ operator works as expected. - """ + """Tests that the __lt__ operator works as expected.""" tunable_cat = Tunable( name="same-name", config={ @@ -96,23 +84,17 @@ def test_tunable_lt_same_name_different_type() -> None: def test_tunable_lt_different_object(tunable_int: Tunable) -> None: - """ - Tests that the __lt__ operator works as expected. - """ + """Tests that the __lt__ operator works as expected.""" assert (tunable_int < "foo") is False with pytest.raises(TypeError): assert "foo" < tunable_int # type: ignore[operator] def test_tunable_group_ne_object(tunable_groups: TunableGroups) -> None: - """ - Tests that the __eq__ operator works as expected with other objects. - """ + """Tests that the __eq__ operator works as expected with other objects.""" assert tunable_groups != "foo" def test_covariant_group_ne_object(covariant_group: CovariantTunableGroup) -> None: - """ - Tests that the __eq__ operator works as expected with other objects. - """ + """Tests that the __eq__ operator works as expected with other objects.""" assert covariant_group != "foo" diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py index f2da3ba60e..410404d66d 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for checking tunable definition rules. -""" +"""Unit tests for checking tunable definition rules.""" import json5 as json import pytest @@ -13,18 +11,14 @@ def test_tunable_name() -> None: - """ - Check that tunable name is valid. - """ + """Check that tunable name is valid.""" with pytest.raises(ValueError): # ! characters are currently disallowed in tunable names Tunable(name='test!tunable', config={"type": "float", "range": [0, 1], "default": 0}) def test_categorical_required_params() -> None: - """ - Check that required parameters are present for categorical tunables. - """ + """Check that required parameters are present for categorical tunables.""" json_config = """ { "type": "categorical", @@ -38,9 +32,7 @@ def test_categorical_required_params() -> None: def test_categorical_weights() -> None: - """ - Instantiate a categorical tunable with weights. - """ + """Instantiate a categorical tunable with weights.""" json_config = """ { "type": "categorical", @@ -55,9 +47,7 @@ def test_categorical_weights() -> None: def test_categorical_weights_wrong_count() -> None: - """ - Try to instantiate a categorical tunable with incorrect number of weights. - """ + """Try to instantiate a categorical tunable with incorrect number of weights.""" json_config = """ { "type": "categorical", @@ -72,9 +62,7 @@ def test_categorical_weights_wrong_count() -> None: def test_categorical_weights_wrong_values() -> None: - """ - Try to instantiate a categorical tunable with invalid weights. - """ + """Try to instantiate a categorical tunable with invalid weights.""" json_config = """ { "type": "categorical", @@ -89,9 +77,7 @@ def test_categorical_weights_wrong_values() -> None: def test_categorical_wrong_params() -> None: - """ - Disallow range param for categorical tunables. 
- """ + """Disallow range param for categorical tunables.""" json_config = """ { "type": "categorical", @@ -106,9 +92,7 @@ def test_categorical_wrong_params() -> None: def test_categorical_disallow_special_values() -> None: - """ - Disallow special values for categorical values. - """ + """Disallow special values for categorical values.""" json_config = """ { "type": "categorical", @@ -123,9 +107,7 @@ def test_categorical_disallow_special_values() -> None: def test_categorical_tunable_disallow_repeats() -> None: - """ - Disallow duplicate values in categorical tunables. - """ + """Disallow duplicate values in categorical tunables.""" with pytest.raises(ValueError): Tunable(name='test', config={ "type": "categorical", @@ -136,9 +118,7 @@ def test_categorical_tunable_disallow_repeats() -> None: @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeName) -> None: - """ - Disallow null values as default for numerical tunables. - """ + """Disallow null values as default for numerical tunables.""" with pytest.raises(ValueError): Tunable(name=f'test_{tunable_type}', config={ "type": tunable_type, @@ -149,9 +129,7 @@ def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeN @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeName) -> None: - """ - Disallow out of range values as default for numerical tunables. - """ + """Disallow out of range values as default for numerical tunables.""" with pytest.raises(ValueError): Tunable(name=f'test_{tunable_type}', config={ "type": tunable_type, @@ -162,9 +140,7 @@ def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeN @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> None: - """ - Disallow values param for numerical tunables. - """ + """Disallow values param for numerical tunables.""" with pytest.raises(ValueError): Tunable(name=f'test_{tunable_type}', config={ "type": tunable_type, @@ -176,9 +152,7 @@ def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> N @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) -> None: - """ - Disallow null values param for numerical tunables. - """ + """Disallow null values param for numerical tunables.""" json_config = f""" {{ "type": "{tunable_type}", @@ -193,9 +167,7 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) - @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> None: - """ - Disallow invalid range param for numerical tunables. - """ + """Disallow invalid range param for numerical tunables.""" json_config = f""" {{ "type": "{tunable_type}", @@ -210,9 +182,7 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> None: - """ - Disallow reverse range param for numerical tunables. 
- """ + """Disallow reverse range param for numerical tunables.""" json_config = f""" {{ "type": "{tunable_type}", @@ -227,9 +197,7 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: - """ - Instantiate a numerical tunable with weighted special values. - """ + """Instantiate a numerical tunable with weighted special values.""" json_config = f""" {{ "type": "{tunable_type}", @@ -249,9 +217,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: - """ - Instantiate a numerical tunable with quantization. - """ + """Instantiate a numerical tunable with quantization.""" json_config = f""" {{ "type": "{tunable_type}", @@ -268,9 +234,7 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_log(tunable_type: TunableValueTypeName) -> None: - """ - Instantiate a numerical tunable with log scale. - """ + """Instantiate a numerical tunable with log scale.""" json_config = f""" {{ "type": "{tunable_type}", @@ -286,9 +250,7 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None: @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> None: - """ - Raise an error if special_weights are specified but no special values. - """ + """Raise an error if special_weights are specified but no special values.""" json_config = f""" {{ "type": "{tunable_type}", @@ -304,9 +266,8 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> None: - """ - Instantiate a numerical tunable with non-normalized weights - of the special values. + """Instantiate a numerical tunable with non-normalized weights of the special + values. """ json_config = f""" {{ @@ -327,9 +288,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> None: - """ - Try to instantiate a numerical tunable with incorrect number of weights. - """ + """Try to instantiate a numerical tunable with incorrect number of weights.""" json_config = f""" {{ "type": "{tunable_type}", @@ -347,9 +306,7 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) -> None: - """ - Try to instantiate a numerical tunable with weights but no range_weight. - """ + """Try to instantiate a numerical tunable with weights but no range_weight.""" json_config = f""" {{ "type": "{tunable_type}", @@ -366,9 +323,7 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) - @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) -> None: - """ - Try to instantiate a numerical tunable with specials but no range_weight. 
- """ + """Try to instantiate a numerical tunable with specials but no range_weight.""" json_config = f""" {{ "type": "{tunable_type}", @@ -385,9 +340,7 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) - @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) -> None: - """ - Try to instantiate a numerical tunable with specials but no range_weight. - """ + """Try to instantiate a numerical tunable with specials but no range_weight.""" json_config = f""" {{ "type": "{tunable_type}", @@ -403,9 +356,7 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> None: - """ - Try to instantiate a numerical tunable with incorrect number of weights. - """ + """Try to instantiate a numerical tunable with incorrect number of weights.""" json_config = f""" {{ "type": "{tunable_type}", @@ -423,9 +374,7 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> None: - """ - Instantiate a numerical tunable with invalid number of quantization points. - """ + """Instantiate a numerical tunable with invalid number of quantization points.""" json_config = f""" {{ "type": "{tunable_type}", @@ -440,9 +389,7 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non def test_bad_type() -> None: - """ - Disallow bad types. - """ + """Disallow bad types.""" json_config = """ { "type": "foo", diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py index deffcb6a46..68c560b1cd 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for checking tunable parameters' distributions. -""" +"""Unit tests for checking tunable parameters' distributions.""" import json5 as json import pytest @@ -13,9 +11,7 @@ def test_categorical_distribution() -> None: - """ - Try to instantiate a categorical tunable with distribution specified. - """ + """Try to instantiate a categorical tunable with distribution specified.""" with pytest.raises(ValueError): Tunable(name='test', config={ "type": "categorical", @@ -29,9 +25,7 @@ def test_categorical_distribution() -> None: @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> None: - """ - Create a numeric Tunable with explicit uniform distribution. - """ + """Create a numeric Tunable with explicit uniform distribution.""" tunable = Tunable(name="test", config={ "type": tunable_type, "range": [0, 10], @@ -47,9 +41,7 @@ def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> N @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> None: - """ - Create a numeric Tunable with explicit Gaussian distribution specified. 
- """ + """Create a numeric Tunable with explicit Gaussian distribution specified.""" tunable = Tunable(name="test", config={ "type": tunable_type, "range": [0, 10], @@ -68,9 +60,7 @@ def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> No @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None: - """ - Create a numeric Tunable with explicit Beta distribution specified. - """ + """Create a numeric Tunable with explicit Beta distribution specified.""" tunable = Tunable(name="test", config={ "type": tunable_type, "range": [0, 10], @@ -89,9 +79,7 @@ def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_distribution_unsupported(tunable_type: str) -> None: - """ - Create a numeric Tunable with unsupported distribution. - """ + """Create a numeric Tunable with unsupported distribution.""" json_config = f""" {{ "type": "{tunable_type}", diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py index c6fb5670f0..eee8a47e3c 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py @@ -2,18 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for checking the indexing rules for tunable groups. -""" +"""Tests for checking the indexing rules for tunable groups.""" from mlos_bench.tunables.tunable import Tunable from mlos_bench.tunables.tunable_groups import TunableGroups def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categorical: Tunable) -> None: - """ - Check that various types of indexing work for the tunable group. - """ + """Check that various types of indexing work for the tunable group.""" # Check that the "in" operator works. assert tunable_categorical in tunable_groups assert tunable_categorical.name in tunable_groups diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py index 55a485e951..274b4d6a43 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py @@ -2,16 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for `TunableGroup.subgroup()` method. -""" +"""Tests for `TunableGroup.subgroup()` method.""" from mlos_bench.tunables.tunable_groups import TunableGroups def test_tunable_group_subgroup(tunable_groups: TunableGroups) -> None: - """ - Check that the subgroup() method returns only a selection of tunable parameters. + """Check that the subgroup() method returns only a selection of tunable + parameters. """ tunables = tunable_groups.subgroup(["provision"]) assert tunables.get_param_values() == {'vmSize': 'Standard_B4ms'} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py index 8a9fba6d86..21f9de84d5 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_update_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
 #
-"""
-Tests for checking the is_updated flag for tunable groups.
-"""
+"""Tests for checking the is_updated flag for tunable groups."""
 
 from mlos_bench.tunables.tunable_groups import TunableGroups
 
@@ -15,16 +13,14 @@
 
 
 def test_tunable_group_update(tunable_groups: TunableGroups) -> None:
-    """
-    Test that updating a tunable group raises the is_updated flag.
-    """
+    """Test that updating a tunable group raises the is_updated flag."""
     tunable_groups.assign(_TUNABLE_VALUES)
     assert tunable_groups.is_updated()
 
 
 def test_tunable_group_update_twice(tunable_groups: TunableGroups) -> None:
-    """
-    Test that updating a tunable group with the same values do *NOT* raises the is_updated flag.
+    """Test that updating a tunable group with the same values does *NOT* raise the
+    is_updated flag.
     """
     tunable_groups.assign(_TUNABLE_VALUES)
     assert tunable_groups.is_updated()
@@ -37,9 +33,7 @@ def test_tunable_group_update_twice(tunable_groups: TunableGroups) -> None:
 
 
 def test_tunable_group_update_kernel(tunable_groups: TunableGroups) -> None:
-    """
-    Test that the is_updated flag is set only for the affected covariant group.
-    """
+    """Test that the is_updated flag is set only for the affected covariant group."""
     tunable_groups.assign(_TUNABLE_VALUES)
     assert tunable_groups.is_updated()
     assert tunable_groups.is_updated(["kernel"])
@@ -47,9 +41,7 @@ def test_tunable_group_update_kernel(tunable_groups: TunableGroups) -> None:
 
 
 def test_tunable_group_update_boot(tunable_groups: TunableGroups) -> None:
-    """
-    Test that the is_updated flag is set only for the affected covariant group.
-    """
+    """Test that the is_updated flag is set only for the affected covariant group."""
     tunable_groups.assign(_TUNABLE_VALUES)
     assert tunable_groups.is_updated()
     assert not tunable_groups.is_updated(["boot"])
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py
index 8d195dd5cf..9d267d4e16 100644
--- a/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py
+++ b/mlos_bench/mlos_bench/tests/tunables/tunable_slice_references_test.py
@@ -2,9 +2,7 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 #
-"""
-Unit tests for unique references to tunables when they're loaded multiple times.
-"""
+"""Unit tests for unique references to tunables when they're loaded multiple times."""
 
 import json5 as json
 import pytest
@@ -13,9 +11,7 @@
 
 
 def test_duplicate_merging_tunable_groups(tunable_groups_config: dict) -> None:
-    """
-    Check that the merging logic of tunable groups works as expected.
-    """
+    """Check that the merging logic of tunable groups works as expected."""
     parent_tunables = TunableGroups(tunable_groups_config)
 
     # Pretend we loaded this one from disk another time.
@@ -63,9 +59,7 @@ def test_duplicate_merging_tunable_groups(tunable_groups_config: dict) -> None:
 
 
 def test_overlapping_group_merge_tunable_groups(tunable_groups_config: dict) -> None:
-    """
-    Check that the merging logic of tunable groups works as expected.
-    """
+    """Check that the merging logic of tunable groups works as expected."""
     parent_tunables = TunableGroups(tunable_groups_config)
 
     # This config should overlap with the parent config.
@@ -94,9 +88,7 @@ def test_overlapping_group_merge_tunable_groups(tunable_groups_config: dict) -> None:
 
 
 def test_bad_extended_merge_tunable_group(tunable_groups_config: dict) -> None:
-    """
-    Check that the merging logic of tunable groups works as expected.
- """ + """Check that the merging logic of tunable groups works as expected.""" parent_tunables = TunableGroups(tunable_groups_config) # This config should overlap with the parent config. @@ -125,9 +117,7 @@ def test_bad_extended_merge_tunable_group(tunable_groups_config: dict) -> None: def test_good_extended_merge_tunable_group(tunable_groups_config: dict) -> None: - """ - Check that the merging logic of tunable groups works as expected. - """ + """Check that the merging logic of tunable groups works as expected.""" parent_tunables = TunableGroups(tunable_groups_config) # This config should overlap with the parent config. diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index 73e3a12caa..dfd4b4c610 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -2,9 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for converting tunable parameters with explicitly -specified distributions to ConfigSpace. +"""Unit tests for converting tunable parameters with explicitly specified distributions +to ConfigSpace. """ import pytest @@ -44,9 +43,7 @@ def test_convert_numerical_distributions(param_type: str, distr_name: DistributionName, distr_params: dict) -> None: - """ - Convert a numerical Tunable with explicit distribution to ConfigSpace. - """ + """Convert a numerical Tunable with explicit distribution to ConfigSpace.""" tunable_name = "x" tunable_groups = TunableGroups({ "tunable_group": { diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 78e91fd25e..7936277ec7 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for Tunable to ConfigSpace conversion. -""" +"""Unit tests for Tunable to ConfigSpace conversion.""" import pytest from ConfigSpace import ( @@ -30,8 +28,8 @@ @pytest.fixture def configuration_space() -> ConfigurationSpace: """ - A test fixture that produces a mock ConfigurationSpace object - matching the tunable_groups fixture. + A test fixture that produces a mock ConfigurationSpace object matching the + tunable_groups fixture. Returns ------- @@ -72,9 +70,7 @@ def configuration_space() -> ConfigurationSpace: def _cmp_tunable_hyperparameter_categorical( tunable: Tunable, space: ConfigurationSpace) -> None: - """ - Check if categorical Tunable and ConfigSpace Hyperparameter actually match. - """ + """Check if categorical Tunable and ConfigSpace Hyperparameter actually match.""" param = space[tunable.name] assert isinstance(param, CategoricalHyperparameter) assert set(param.choices) == set(tunable.categories) @@ -83,9 +79,7 @@ def _cmp_tunable_hyperparameter_categorical( def _cmp_tunable_hyperparameter_numerical( tunable: Tunable, space: ConfigurationSpace) -> None: - """ - Check if integer Tunable and ConfigSpace Hyperparameter actually match. 
- """ + """Check if integer Tunable and ConfigSpace Hyperparameter actually match.""" param = space[tunable.name] assert isinstance(param, (UniformIntegerHyperparameter, UniformFloatHyperparameter)) assert (param.lower, param.upper) == tuple(tunable.range) @@ -94,25 +88,19 @@ def _cmp_tunable_hyperparameter_numerical( def test_tunable_to_configspace_categorical(tunable_categorical: Tunable) -> None: - """ - Check the conversion of Tunable to CategoricalHyperparameter. - """ + """Check the conversion of Tunable to CategoricalHyperparameter.""" cs_param = _tunable_to_configspace(tunable_categorical) _cmp_tunable_hyperparameter_categorical(tunable_categorical, cs_param) def test_tunable_to_configspace_int(tunable_int: Tunable) -> None: - """ - Check the conversion of Tunable to UniformIntegerHyperparameter. - """ + """Check the conversion of Tunable to UniformIntegerHyperparameter.""" cs_param = _tunable_to_configspace(tunable_int) _cmp_tunable_hyperparameter_numerical(tunable_int, cs_param) def test_tunable_to_configspace_float(tunable_float: Tunable) -> None: - """ - Check the conversion of Tunable to UniformFloatHyperparameter. - """ + """Check the conversion of Tunable to UniformFloatHyperparameter.""" cs_param = _tunable_to_configspace(tunable_float) _cmp_tunable_hyperparameter_numerical(tunable_float, cs_param) @@ -127,6 +115,7 @@ def test_tunable_to_configspace_float(tunable_float: Tunable) -> None: def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> None: """ Check the conversion of TunableGroups to ConfigurationSpace. + Make sure that the corresponding Tunable and Hyperparameter objects match. """ space = tunable_groups_to_configspace(tunable_groups) @@ -136,9 +125,8 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> Non def test_tunable_groups_to_configspace( tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None: - """ - Check the conversion of the entire TunableGroups collection - to a single ConfigurationSpace object. + """Check the conversion of the entire TunableGroups collection to a single + ConfigurationSpace object. """ space = tunable_groups_to_configspace(tunable_groups) assert space == configuration_space diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index cbccd6bfe1..5893e9440a 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -2,8 +2,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for assigning values to the individual parameters within tunable groups. +"""Unit tests for assigning values to the individual parameters within tunable +groups. """ import json5 as json @@ -14,9 +14,8 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: - """ - Make sure that bulk assignment fails for parameters - that don't exist in the TunableGroups object. + """Make sure that bulk assignment fails for parameters that don't exist in the + TunableGroups object. """ with pytest.raises(KeyError): tunable_groups.assign({ @@ -28,118 +27,90 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None: - """ - Regular assignment for categorical tunable. 
- """ + """Regular assignment for categorical tunable.""" # Must be one of: {"Standard_B2s", "Standard_B2ms", "Standard_B4ms"} tunable_categorical.value = "Standard_B4ms" assert not tunable_categorical.is_special def test_tunables_assign_invalid_categorical(tunable_groups: TunableGroups) -> None: - """ - Check parameter validation for categorical tunables. - """ + """Check parameter validation for categorical tunables.""" with pytest.raises(ValueError): tunable_groups.assign({"vmSize": "InvalidSize"}) def test_tunables_assign_invalid_range(tunable_groups: TunableGroups) -> None: - """ - Check parameter out-of-range validation for numerical tunables. - """ + """Check parameter out-of-range validation for numerical tunables.""" with pytest.raises(ValueError): tunable_groups.assign({"kernel_sched_migration_cost_ns": -2}) def test_tunables_assign_coerce_str(tunable_groups: TunableGroups) -> None: - """ - Check the conversion from strings when assigning to an integer parameter. - """ + """Check the conversion from strings when assigning to an integer parameter.""" tunable_groups.assign({"kernel_sched_migration_cost_ns": "10000"}) def test_tunables_assign_coerce_str_range_check(tunable_groups: TunableGroups) -> None: - """ - Check the range when assigning to an integer tunable. - """ + """Check the range when assigning to an integer tunable.""" with pytest.raises(ValueError): tunable_groups.assign({"kernel_sched_migration_cost_ns": "5500000"}) def test_tunables_assign_coerce_str_invalid(tunable_groups: TunableGroups) -> None: - """ - Make sure we fail when assigning an invalid string to an integer tunable. - """ + """Make sure we fail when assigning an invalid string to an integer tunable.""" with pytest.raises(ValueError): tunable_groups.assign({"kernel_sched_migration_cost_ns": "1.1"}) def test_tunable_assign_str_to_numerical(tunable_int: Tunable) -> None: - """ - Check str to int coercion. - """ + """Check str to int coercion.""" with pytest.raises(ValueError): tunable_int.numerical_value = "foo" # type: ignore[assignment] def test_tunable_assign_int_to_numerical_value(tunable_int: Tunable) -> None: - """ - Check numerical value assignment. - """ + """Check numerical value assignment.""" tunable_int.numerical_value = 10.0 assert tunable_int.numerical_value == 10 assert not tunable_int.is_special def test_tunable_assign_float_to_numerical_value(tunable_float: Tunable) -> None: - """ - Check numerical value assignment. - """ + """Check numerical value assignment.""" tunable_float.numerical_value = 0.1 assert tunable_float.numerical_value == 0.1 assert not tunable_float.is_special def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None: - """ - Check str to int coercion. - """ + """Check str to int coercion.""" tunable_int.value = "10" assert tunable_int.value == 10 # type: ignore[comparison-overlap] assert not tunable_int.is_special def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None: - """ - Check str to float coercion. - """ + """Check str to float coercion.""" tunable_float.value = "0.5" assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] assert not tunable_float.is_special def test_tunable_assign_float_to_int(tunable_int: Tunable) -> None: - """ - Check float to int coercion. - """ + """Check float to int coercion.""" tunable_int.value = 10.0 assert tunable_int.value == 10 assert not tunable_int.is_special def test_tunable_assign_float_to_int_fail(tunable_int: Tunable) -> None: - """ - Check the invalid float to int coercion. 
- """ + """Check the invalid float to int coercion.""" with pytest.raises(ValueError): tunable_int.value = 10.1 def test_tunable_assign_null_to_categorical() -> None: - """ - Checks that we can use null/None in categorical tunables. - """ + """Checks that we can use null/None in categorical tunables.""" json_config = """ { "name": "categorical_test", @@ -159,9 +130,7 @@ def test_tunable_assign_null_to_categorical() -> None: def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None: - """ - Checks that we can't use null/None in integer tunables. - """ + """Checks that we can't use null/None in integer tunables.""" with pytest.raises((TypeError, AssertionError)): tunable_int.value = None with pytest.raises((TypeError, AssertionError)): @@ -169,9 +138,7 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None: def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: - """ - Checks that we can't use null/None in float tunables. - """ + """Checks that we can't use null/None in float tunables.""" with pytest.raises((TypeError, AssertionError)): tunable_float.value = None with pytest.raises((TypeError, AssertionError)): @@ -179,8 +146,8 @@ def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: def test_tunable_assign_special(tunable_int: Tunable) -> None: - """ - Check the assignment of a special value outside of the range (but declared `special`). + """Check the assignment of a special value outside of the range (but declared + `special`). """ tunable_int.numerical_value = -1 assert tunable_int.numerical_value == -1 @@ -188,16 +155,16 @@ def test_tunable_assign_special(tunable_int: Tunable) -> None: def test_tunable_assign_special_fail(tunable_int: Tunable) -> None: - """ - Assign a value that is neither special nor in range and fail. - """ + """Assign a value that is neither special nor in range and fail.""" with pytest.raises(ValueError): tunable_int.numerical_value = -2 def test_tunable_assign_special_with_coercion(tunable_int: Tunable) -> None: """ - Check the assignment of a special value outside of the range (but declared `special`). + Check the assignment of a special value outside of the range (but declared + `special`). + Check coercion from float to int. """ tunable_int.numerical_value = -1.0 @@ -207,7 +174,9 @@ def test_tunable_assign_special_with_coercion(tunable_int: Tunable) -> None: def test_tunable_assign_special_with_coercion_str(tunable_int: Tunable) -> None: """ - Check the assignment of a special value outside of the range (but declared `special`). + Check the assignment of a special value outside of the range (but declared + `special`). + Check coercion from string to int. """ tunable_int.value = "-1" diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py index 16bb42500c..c5395fcb16 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for deep copy of tunable objects and groups. -""" +"""Unit tests for deep copy of tunable objects and groups.""" from mlos_bench.tunables.covariant_group import CovariantTunableGroup from mlos_bench.tunables.tunable import Tunable, TunableValue @@ -12,9 +10,7 @@ def test_copy_tunable_int(tunable_int: Tunable) -> None: - """ - Check if deep copy works for Tunable object. 
- """ + """Check if deep copy works for Tunable object.""" tunable_copy = tunable_int.copy() assert tunable_int == tunable_copy tunable_copy.numerical_value += 200 @@ -22,9 +18,7 @@ def test_copy_tunable_int(tunable_int: Tunable) -> None: def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None: - """ - Check if deep copy works for TunableGroups object. - """ + """Check if deep copy works for TunableGroups object.""" tunable_groups_copy = tunable_groups.copy() assert tunable_groups == tunable_groups_copy tunable_groups_copy["vmSize"] = "Standard_B2ms" @@ -34,9 +28,7 @@ def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None: def test_copy_covariant_group(covariant_group: CovariantTunableGroup) -> None: - """ - Check if deep copy works for TunableGroups object. - """ + """Check if deep copy works for TunableGroups object.""" covariant_group_copy = covariant_group.copy() assert covariant_group == covariant_group_copy tunable = next(iter(covariant_group.get_tunables())) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py index 672b16ab73..1f909a63e7 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py @@ -2,19 +2,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests to make sure we always produce a string representation -of a TunableGroup in canonical form. +"""Unit tests to make sure we always produce a string representation of a TunableGroup +in canonical form. """ from mlos_bench.tunables.tunable_groups import TunableGroups def test_tunable_groups_str(tunable_groups: TunableGroups) -> None: - """ - Check that we produce the same string representation of TunableGroups, - regardless of the order in which we declare the covariant groups and - tunables within each covariant group. + """Check that we produce the same string representation of TunableGroups, regardless + of the order in which we declare the covariant groups and tunables within each + covariant group. """ # Same as `tunable_groups` (defined in the `conftest.py` file), but in different order: tunables_other = TunableGroups({ diff --git a/mlos_bench/mlos_bench/tests/util_git_test.py b/mlos_bench/mlos_bench/tests/util_git_test.py index 54946fca6e..77fd2779c7 100644 --- a/mlos_bench/mlos_bench/tests/util_git_test.py +++ b/mlos_bench/mlos_bench/tests/util_git_test.py @@ -2,18 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for get_git_info utility function. -""" +"""Unit tests for get_git_info utility function.""" import re from mlos_bench.util import get_git_info def test_get_git_info() -> None: - """ - Check that we can retrieve git info about the current repository correctly. - """ + """Check that we can retrieve git info about the current repository correctly.""" (git_repo, git_commit, rel_path) = get_git_info(__file__) assert "mlos" in git_repo.lower() assert re.match(r"[0-9a-f]{40}", git_commit) is not None diff --git a/mlos_bench/mlos_bench/tests/util_nullable_test.py b/mlos_bench/mlos_bench/tests/util_nullable_test.py index 28ed7fc92c..f0ca82eb6e 100644 --- a/mlos_bench/mlos_bench/tests/util_nullable_test.py +++ b/mlos_bench/mlos_bench/tests/util_nullable_test.py @@ -2,18 +2,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for `nullable` utility function. 
-""" +"""Unit tests for `nullable` utility function.""" import pytest from mlos_bench.util import nullable def test_nullable_str() -> None: - """ - Check that the `nullable` function works properly for `str`. - """ + """Check that the `nullable` function works properly for `str`.""" assert nullable(str, None) is None assert nullable(str, "") is not None assert nullable(str, "") == "" @@ -22,9 +18,7 @@ def test_nullable_str() -> None: def test_nullable_int() -> None: - """ - Check that the `nullable` function works properly for `int`. - """ + """Check that the `nullable` function works properly for `int`.""" assert nullable(int, None) is None assert nullable(int, 10) is not None assert nullable(int, 10) == 10 @@ -32,9 +26,7 @@ def test_nullable_int() -> None: def test_nullable_func() -> None: - """ - Check that the `nullable` function works properly with `list.pop()` function. - """ + """Check that the `nullable` function works properly with `list.pop()` function.""" assert nullable(list.pop, None) is None assert nullable(list.pop, [1, 2, 3]) == 3 diff --git a/mlos_bench/mlos_bench/tests/util_try_parse_test.py b/mlos_bench/mlos_bench/tests/util_try_parse_test.py index b613c19694..d97acd0b8c 100644 --- a/mlos_bench/mlos_bench/tests/util_try_parse_test.py +++ b/mlos_bench/mlos_bench/tests/util_try_parse_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for try_parse_val utility function. -""" +"""Unit tests for try_parse_val utility function.""" import math @@ -12,9 +10,7 @@ def test_try_parse_val() -> None: - """ - Check that we can retrieve git info about the current repository correctly. - """ + """Check that we can retrieve git info about the current repository correctly.""" assert try_parse_val(None) is None assert try_parse_val("1") == int(1) assert try_parse_val("1.1") == float(1.1) diff --git a/mlos_bench/mlos_bench/tunables/__init__.py b/mlos_bench/mlos_bench/tunables/__init__.py index 4191f37d89..58106a606e 100644 --- a/mlos_bench/mlos_bench/tunables/__init__.py +++ b/mlos_bench/mlos_bench/tunables/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tunables classes for Environments in mlos_bench. -""" +"""Tunables classes for Environments in mlos_bench.""" from mlos_bench.tunables.tunable import Tunable, TunableValue from mlos_bench.tunables.tunable_groups import TunableGroups diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index fee4fd5841..3eba2cb9db 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tunable parameter definition. -""" +"""Tunable parameter definition.""" import copy from typing import Dict, Iterable, Union @@ -14,6 +12,7 @@ class CovariantTunableGroup: """ A collection of tunable parameters. + Changing any of the parameters in the group incurs the same cost of the experiment. """ @@ -52,9 +51,9 @@ def name(self) -> str: @property def cost(self) -> int: """ - Get the cost of changing the values in the covariant group. - This value is a constant. Use `get_current_cost()` to get - the cost given the group update status. + Get the cost of changing the values in the covariant group. This value is a + constant. Use `get_current_cost()` to get the cost given the group update + status. 
Returns ------- @@ -100,8 +99,8 @@ def __eq__(self, other: object) -> bool: def equals_defaults(self, other: "CovariantTunableGroup") -> bool: """ - Checks to see if the other CovariantTunableGroup is the same, ignoring - the current values of the two groups' Tunables. + Checks to see if the other CovariantTunableGroup is the same, ignoring the + current values of the two groups' Tunables. Parameters ---------- @@ -126,7 +125,8 @@ def equals_defaults(self, other: "CovariantTunableGroup") -> bool: def is_defaults(self) -> bool: """ - Checks whether the currently assigned values of all tunables are at their defaults. + Checks whether the currently assigned values of all tunables are at their + defaults. Returns ------- @@ -135,9 +135,7 @@ def is_defaults(self) -> bool: return all(tunable.is_default() for tunable in self._tunables.values()) def restore_defaults(self) -> None: - """ - Restore all tunable parameters to their default values. - """ + """Restore all tunable parameters to their default values.""" for tunable in self._tunables.values(): if tunable.value != tunable.default: self._is_updated = True @@ -145,8 +143,10 @@ def restore_defaults(self) -> None: def reset_is_updated(self) -> None: """ - Clear the update flag. That is, state that running an experiment with the - current values of the tunables in this group has no extra cost. + Clear the update flag. + + That is, state that running an experiment with the current values of the + tunables in this group has no extra cost. """ self._is_updated = False @@ -173,9 +173,7 @@ def get_current_cost(self) -> int: return self._cost if self._is_updated else 0 def get_names(self) -> Iterable[str]: - """ - Get the names of all tunables in the group. - """ + """Get the names of all tunables in the group.""" return self._tunables.keys() def get_tunable_values_dict(self) -> Dict[str, TunableValue]: @@ -190,8 +188,8 @@ def get_tunable_values_dict(self) -> Dict[str, TunableValue]: def __repr__(self) -> str: """ - Produce a human-readable version of the CovariantTunableGroup - (mostly for logging). + Produce a human-readable version of the CovariantTunableGroup (mostly for + logging). Returns ------- @@ -202,8 +200,8 @@ def __repr__(self) -> str: def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable: """ - Access the entire Tunable in a group (not just its value). - Throw KeyError if the tunable is not in the group. + Access the entire Tunable in a group (not just its value). Throw KeyError if the + tunable is not in the group. Parameters ---------- @@ -219,7 +217,8 @@ def get_tunable(self, tunable: Union[str, Tunable]) -> Tunable: return self._tunables[name] def get_tunables(self) -> Iterable[Tunable]: - """Gets the set of tunables for this CovariantTunableGroup. + """ + Gets the set of tunables for this CovariantTunableGroup. Returns ------- diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 1ebd70dfa4..7bda45e49e 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tunable parameter definition. 
-""" +"""Tunable parameter definition.""" import collections import copy import logging @@ -27,34 +25,26 @@ from mlos_bench.util import nullable _LOG = logging.getLogger(__name__) - - """A tunable parameter value type alias.""" TunableValue = Union[int, float, Optional[str]] - """Tunable value type.""" TunableValueType = Union[Type[int], Type[float], Type[str]] - """ Tunable value type tuple. + For checking with isinstance() """ TunableValueTypeTuple = (int, float, str, type(None)) - """The string name of a tunable value type.""" TunableValueTypeName = Literal["int", "float", "categorical"] - -"""Tunable values dictionary type""" +"""Tunable values dictionary type.""" TunableValuesDict = Dict[str, TunableValue] - -"""Tunable value distribution type""" +"""Tunable value distribution type.""" DistributionName = Literal["uniform", "normal", "beta"] class DistributionDict(TypedDict, total=False): - """ - A typed dict for tunable parameters' distributions. - """ + """A typed dict for tunable parameters' distributions.""" type: DistributionName params: Optional[Dict[str, float]] @@ -85,9 +75,7 @@ class TunableDict(TypedDict, total=False): class Tunable: # pylint: disable=too-many-instance-attributes,too-many-public-methods - """ - A tunable parameter definition and its current value. - """ + """A tunable parameter definition and its current value.""" # Maps tunable types to their corresponding Python types by name. _DTYPE: Dict[TunableValueTypeName, TunableValueType] = { @@ -144,8 +132,8 @@ def __init__(self, name: str, config: TunableDict): self.value = self._default def _sanity_check(self) -> None: - """ - Check if the status of the Tunable is valid, and throw ValueError if it is not. + """Check if the status of the Tunable is valid, and throw ValueError if it is + not. """ if self.is_categorical: self._sanity_check_categorical() @@ -157,8 +145,8 @@ def _sanity_check(self) -> None: raise ValueError(f"Invalid default value for tunable {self}: {self.default}") def _sanity_check_categorical(self) -> None: - """ - Check if the status of the categorical Tunable is valid, and throw ValueError if it is not. + """Check if the status of the categorical Tunable is valid, and throw ValueError + if it is not. """ # pylint: disable=too-complex assert self.is_categorical @@ -185,8 +173,8 @@ def _sanity_check_categorical(self) -> None: raise ValueError(f"All weights must be non-negative: {self}") def _sanity_check_numerical(self) -> None: - """ - Check if the status of the numerical Tunable is valid, and throw ValueError if it is not. + """Check if the status of the numerical Tunable is valid, and throw ValueError + if it is not. """ # pylint: disable=too-complex,too-many-branches assert self.is_numerical @@ -303,29 +291,23 @@ def copy(self) -> "Tunable": @property def default(self) -> TunableValue: - """ - Get the default value of the tunable. - """ + """Get the default value of the tunable.""" return self._default def is_default(self) -> TunableValue: - """ - Checks whether the currently assigned value of the tunable is at its default. + """Checks whether the currently assigned value of the tunable is at its + default. """ return self._default == self._current_value @property def value(self) -> TunableValue: - """ - Get the current value of the tunable. - """ + """Get the current value of the tunable.""" return self._current_value @value.setter def value(self, value: TunableValue) -> TunableValue: - """ - Set the current value of the tunable. 
- """ + """Set the current value of the tunable.""" # We need this coercion for the values produced by some optimizers # (e.g., scikit-optimize) and for data restored from certain storage # systems (where values can be strings). @@ -355,7 +337,8 @@ def value(self, value: TunableValue) -> TunableValue: def update(self, value: TunableValue) -> bool: """ - Assign the value to the tunable. Return True if it is a new value, False otherwise. + Assign the value to the tunable. Return True if it is a new value, False + otherwise. Parameters ---------- @@ -399,8 +382,9 @@ def is_valid(self, value: TunableValue) -> bool: def in_range(self, value: Union[int, float, str, None]) -> bool: """ Check if the value is within the range of the tunable. - Do *NOT* check for special values. - Return False if the tunable or value is categorical or None. + + Do *NOT* check for special values. Return False if the tunable or value is + categorical or None. """ return ( isinstance(value, (float, int)) and @@ -411,9 +395,7 @@ def in_range(self, value: Union[int, float, str, None]) -> bool: @property def category(self) -> Optional[str]: - """ - Get the current value of the tunable as a number. - """ + """Get the current value of the tunable as a number.""" if self.is_categorical: return nullable(str, self._current_value) else: @@ -421,9 +403,7 @@ def category(self) -> Optional[str]: @category.setter def category(self, new_value: Optional[str]) -> Optional[str]: - """ - Set the current value of the tunable. - """ + """Set the current value of the tunable.""" assert self.is_categorical assert isinstance(new_value, (str, type(None))) self.value = new_value @@ -431,9 +411,7 @@ def category(self, new_value: Optional[str]) -> Optional[str]: @property def numerical_value(self) -> Union[int, float]: - """ - Get the current value of the tunable as a number. - """ + """Get the current value of the tunable as a number.""" assert self._current_value is not None if self._type == "int": return int(self._current_value) @@ -444,9 +422,7 @@ def numerical_value(self) -> Union[int, float]: @numerical_value.setter def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]: - """ - Set the current numerical value of the tunable. - """ + """Set the current numerical value of the tunable.""" # We need this coercion for the values produced by some optimizers # (e.g., scikit-optimize) and for data restored from certain storage # systems (where values can be strings). @@ -456,9 +432,7 @@ def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]: @property def name(self) -> str: - """ - Get the name / string ID of the tunable. - """ + """Get the name / string ID of the tunable.""" return self._name @property @@ -488,8 +462,8 @@ def is_special(self) -> bool: @property def weights(self) -> Optional[List[float]]: """ - Get the weights of the categories or special values of the tunable. - Return None if there are none. + Get the weights of the categories or special values of the tunable. Return None + if there are none. Returns ------- @@ -501,8 +475,8 @@ def weights(self) -> Optional[List[float]]: @property def range_weight(self) -> Optional[float]: """ - Get weight of the range of the numeric tunable. - Return None if there are no weights or a tunable is categorical. + Get weight of the range of the numeric tunable. Return None if there are no + weights or a tunable is categorical. 
Returns ------- @@ -693,8 +667,8 @@ def distribution_params(self) -> Dict[str, float]: @property def categories(self) -> List[Optional[str]]: """ - Get the list of all possible values of a categorical tunable. - Return None if the tunable is not categorical. + Get the list of all possible values of a categorical tunable. Return None if the + tunable is not categorical. Returns ------- @@ -723,7 +697,9 @@ def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Itera @property def meta(self) -> Dict[str, Any]: """ - Get the tunable's metadata. This is a free-form dictionary that can be used to - store any additional information about the tunable (e.g., the unit information). + Get the tunable's metadata. + + This is a free-form dictionary that can be used to store any additional + information about the tunable (e.g., the unit information). """ return self._meta diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index 0bd58c8269..bc56f20f45 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -TunableGroups definition. -""" +"""TunableGroups definition.""" import copy from typing import Dict, Generator, Iterable, Mapping, Optional, Tuple, Union @@ -14,9 +12,7 @@ class TunableGroups: - """ - A collection of covariant groups of tunable parameters. - """ + """A collection of covariant groups of tunable parameters.""" def __init__(self, config: Optional[dict] = None): """ @@ -138,23 +134,17 @@ def __repr__(self) -> str: for tunable in sorted(group._tunables.values())) + " }" def __contains__(self, tunable: Union[str, Tunable]) -> bool: - """ - Checks if the given name/tunable is in this tunable group. - """ + """Checks if the given name/tunable is in this tunable group.""" name: str = tunable.name if isinstance(tunable, Tunable) else tunable return name in self._index def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: - """ - Get the current value of a single tunable parameter. - """ + """Get the current value of a single tunable parameter.""" name: str = tunable.name if isinstance(tunable, Tunable) else tunable return self._index[name][name] def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: - """ - Update the current value of a single tunable parameter. - """ + """Update the current value of a single tunable parameter.""" # Use double index to make sure we set the is_updated flag of the group name: str = tunable.name if isinstance(tunable, Tunable) else tunable value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value @@ -175,8 +165,8 @@ def __iter__(self) -> Generator[Tuple[Tunable, CovariantTunableGroup], None, Non def get_tunable(self, tunable: Union[str, Tunable]) -> Tuple[Tunable, CovariantTunableGroup]: """ - Access the entire Tunable (not just its value) and its covariant group. - Throw KeyError if the tunable is not found. + Access the entire Tunable (not just its value) and its covariant group. Throw + KeyError if the tunable is not found. 
Parameters ---------- @@ -205,8 +195,8 @@ def get_covariant_group_names(self) -> Iterable[str]: def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": """ - Select the covariance groups from the current set and create a new - TunableGroups object that consists of those covariance groups. + Select the covariance groups from the current set and create a new TunableGroups + object that consists of those covariance groups. Note: The new TunableGroup will include *references* (not copies) to original ones, so each will get updated together. @@ -235,7 +225,8 @@ def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": def get_param_values(self, group_names: Optional[Iterable[str]] = None, into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: """ - Get the current values of the tunables that belong to the specified covariance groups. + Get the current values of the tunables that belong to the specified covariance + groups. Parameters ---------- @@ -277,7 +268,8 @@ def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: def is_defaults(self) -> bool: """ - Checks whether the currently assigned values of all tunables are at their defaults. + Checks whether the currently assigned values of all tunables are at their + defaults. Returns ------- @@ -323,8 +315,8 @@ def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": def assign(self, param_values: Mapping[str, TunableValue]) -> "TunableGroups": """ - In-place update the values of the tunables from the dictionary - of (key, value) pairs. + In-place update the values of the tunables from the dictionary of (key, value) + pairs. Parameters ---------- diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 531988be97..d516eb5337 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Various helper functions for mlos_bench. -""" +"""Various helper functions for mlos_bench.""" # NOTE: This has to be placed in the top-level mlos_bench package to avoid circular imports. @@ -48,8 +46,8 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> dict: """ - Replaces all $name values in the destination config with the corresponding - value from the source config. + Replaces all $name values in the destination config with the corresponding value + from the source config. Parameters ---------- @@ -74,9 +72,8 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> def merge_parameters(*, dest: dict, source: Optional[dict] = None, required_keys: Optional[Iterable[str]] = None) -> dict: """ - Merge the source config dict into the destination config. - Pick from the source configs *ONLY* the keys that are already present - in the destination config. + Merge the source config dict into the destination config. Pick from the source + configs *ONLY* the keys that are already present in the destination config. Parameters ---------- @@ -222,8 +219,8 @@ def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str, def check_required_params(config: Mapping[str, Any], required_params: Iterable[str]) -> None: """ - Check if all required parameters are present in the configuration. - Raise ValueError if any of the parameters are missing. + Check if all required parameters are present in the configuration. Raise ValueError + if any of the parameters are missing. 
Parameters ---------- @@ -356,9 +353,7 @@ def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]: - """ - A nullable version of utcify_timestamp. - """ + """A nullable version of utcify_timestamp.""" return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None diff --git a/mlos_bench/mlos_bench/version.py b/mlos_bench/mlos_bench/version.py index 96d3d2b6bf..520192b647 100644 --- a/mlos_bench/mlos_bench/version.py +++ b/mlos_bench/mlos_bench/version.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Version number for the mlos_bench package. -""" +"""Version number for the mlos_bench package.""" # NOTE: This should be managed by bumpversion. VERSION = '0.5.1' diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index 27d844c35b..9e00657dfa 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Setup instructions for the mlos_bench package. -""" +"""Setup instructions for the mlos_bench package.""" # pylint: disable=duplicate-code diff --git a/mlos_core/mlos_core/__init__.py b/mlos_core/mlos_core/__init__.py index 3d816eb916..41d24af928 100644 --- a/mlos_core/mlos_core/__init__.py +++ b/mlos_core/mlos_core/__init__.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Basic initializer module for the mlos_core package. -""" +"""Basic initializer module for the mlos_core package.""" diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index 086002af62..c72600be02 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Basic initializer module for the mlos_core optimizers. -""" +"""Basic initializer module for the mlos_core optimizers.""" from enum import Enum from typing import List, Optional, TypeVar @@ -31,13 +29,13 @@ class OptimizerType(Enum): """Enumerate supported MlosCore optimizers.""" RANDOM = RandomOptimizer - """An instance of RandomOptimizer class will be used""" + """An instance of RandomOptimizer class will be used.""" FLAML = FlamlOptimizer - """An instance of FlamlOptimizer class will be used""" + """An instance of FlamlOptimizer class will be used.""" SMAC = SmacOptimizer - """An instance of SmacOptimizer class will be used""" + """An instance of SmacOptimizer class will be used.""" # To make mypy happy, we need to define a type variable for each optimizer type. @@ -55,7 +53,7 @@ class OptimizerType(Enum): class OptimizerFactory: - """Simple factory class for creating BaseOptimizer-derived objects""" + """Simple factory class for creating BaseOptimizer-derived objects.""" # pylint: disable=too-few-public-methods @@ -68,8 +66,8 @@ def create(*, space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, space_adapter_kwargs: Optional[dict] = None) -> ConcreteOptimizer: # type: ignore[type-var] """ - Create a new optimizer instance, given the parameter space, optimizer type, - and potential optimizer options. + Create a new optimizer instance, given the parameter space, optimizer type, and + potential optimizer options. 
Parameters ---------- diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py index 5f32219988..d4b7294f32 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Basic initializer module for the mlos_core Bayesian optimizers. -""" +"""Basic initializer module for the mlos_core Bayesian optimizers.""" from mlos_core.optimizers.bayesian_optimizers.bayesian_optimizer import ( BaseBayesianOptimizer, diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 76ff0d9b3a..de333be46e 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Contains the wrapper classes for base Bayesian optimizers. -""" +"""Contains the wrapper classes for base Bayesian optimizers.""" from abc import ABCMeta, abstractmethod from typing import Optional @@ -21,7 +19,9 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): @abstractmethod def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: - """Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). + """ + Obtain a prediction from this Bayesian optimizer's surrogate model for the given + configuration(s). Parameters ---------- @@ -36,7 +36,9 @@ def surrogate_predict(self, *, configs: pd.DataFrame, @abstractmethod def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: - """Invokes the acquisition function from this Bayesian optimizer for the given configuration. + """ + Invokes the acquisition function from this Bayesian optimizer for the given + configuration. Parameters ---------- diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 9d8d2a0347..e86d868cdb 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -4,6 +4,7 @@ # """ Contains the wrapper class for SMAC Bayesian optimizers. + See Also: """ @@ -25,9 +26,7 @@ class SmacOptimizer(BaseBayesianOptimizer): - """ - Wrapper class for SMAC based Bayesian optimization. - """ + """Wrapper class for SMAC based Bayesian optimization.""" def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments parameter_space: ConfigSpace.ConfigurationSpace, @@ -212,7 +211,8 @@ def __del__(self) -> None: @property def n_random_init(self) -> int: """ - Gets the number of random samples to use to initialize the optimizer's search space sampling. + Gets the number of random samples to use to initialize the optimizer's search + space sampling. Note: This may not be equal to the value passed to the initializer, due to logic present in the SMAC. See Also: max_ratio @@ -227,7 +227,8 @@ def n_random_init(self) -> int: @staticmethod def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None: - """Dummy target function for SMAC optimizer. 
+ """ + Dummy target function for SMAC optimizer. Since we only use the ask-and-tell interface, this is never called. @@ -245,7 +246,8 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: - """Registers the given configs and scores. + """ + Registers the given configs and scores. Parameters ---------- @@ -282,7 +284,8 @@ def _register(self, *, configs: pd.DataFrame, self.base_optimizer.optimizer.save() def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: - """Suggests a new configuration. + """ + Suggests a new configuration. Parameters ---------- @@ -359,7 +362,8 @@ def cleanup(self) -> None: self._temp_output_directory = None def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace.Configuration]: - """Convert a dataframe of configs to a list of ConfigSpace configs. + """ + Convert a dataframe of configs to a list of ConfigSpace configs. Parameters ---------- diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 273c89eecc..aaefdbdf3d 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Contains the FlamlOptimizer class. -""" +"""Contains the FlamlOptimizer class.""" from typing import Dict, List, NamedTuple, Optional, Tuple, Union from warnings import warn @@ -26,9 +24,7 @@ class EvaluatedSample(NamedTuple): class FlamlOptimizer(BaseOptimizer): - """ - Wrapper class for FLAML Optimizer: A fast library for AutoML and tuning. - """ + """Wrapper class for FLAML Optimizer: A fast library for AutoML and tuning.""" # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. _METRIC_NAME = "FLAML_score" @@ -90,7 +86,8 @@ def __init__(self, *, # pylint: disable=too-many-arguments def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: - """Registers the given configs and scores. + """ + Registers the given configs and scores. Parameters ---------- @@ -122,7 +119,8 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, ) def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: - """Suggests a new configuration. + """ + Suggests a new configuration. Sampled at random using ConfigSpace. @@ -149,7 +147,8 @@ def register_pending(self, *, configs: pd.DataFrame, raise NotImplementedError() def _target_function(self, config: dict) -> Union[dict, None]: - """Configuration evaluation function called by FLAML optimizer. + """ + Configuration evaluation function called by FLAML optimizer. FLAML may suggest the same configuration multiple times (due to its warm-start mechanism). Once FLAML suggests an unseen configuration, we store it, and stop the optimization process. @@ -173,7 +172,9 @@ def _target_function(self, config: dict) -> Union[dict, None]: return None # Returning None stops the process def _get_next_config(self) -> dict: - """Warm-starts a new instance of FLAML, and returns a recommended, unseen new configuration. 
+ """ + Warm-starts a new instance of FLAML, and returns a recommended, unseen new + configuration. Since FLAML does not provide an ask-and-tell interface, we need to create a new instance of FLAML each time we get asked for a new suggestion. This is suboptimal performance-wise, but works. diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 4ab9db5a2f..d9b37910b5 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Contains the BaseOptimizer abstract class. -""" +"""Contains the BaseOptimizer abstract class.""" import collections from abc import ABCMeta, abstractmethod @@ -20,9 +18,7 @@ class BaseOptimizer(metaclass=ABCMeta): - """ - Optimizer abstract base class defining the basic interface. - """ + """Optimizer abstract base class defining the basic interface.""" def __init__(self, *, parameter_space: ConfigSpace.ConfigurationSpace, @@ -70,7 +66,9 @@ def space_adapter(self) -> Optional[BaseSpaceAdapter]: def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: - """Wrapper method, which employs the space adapter (if any), before registering the configs and scores. + """ + Wrapper method, which employs the space adapter (if any), before registering the + configs and scores. Parameters ---------- @@ -110,7 +108,8 @@ def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, @abstractmethod def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: - """Registers the given configs and scores. + """ + Registers the given configs and scores. Parameters ---------- @@ -127,7 +126,8 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, def suggest(self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ - Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. + Wrapper method, which employs the space adapter (if any), after suggesting a new + configuration. Parameters ---------- @@ -161,7 +161,8 @@ def suggest(self, *, context: Optional[pd.DataFrame] = None, @abstractmethod def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: - """Suggests a new configuration. + """ + Suggests a new configuration. Parameters ---------- @@ -182,9 +183,10 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr def register_pending(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: - """Registers the given configs as "pending". - That is it say, it has been suggested by the optimizer, and an experiment trial has been started. - This can be useful for executing multiple trials in parallel, retry logic, etc. + """ + Registers the given configs as "pending". That is it say, it has been suggested + by the optimizer, and an experiment trial has been started. This can be useful + for executing multiple trials in parallel, retry logic, etc. 
Parameters ---------- @@ -216,9 +218,10 @@ def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.Data def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ - Get the N best observations so far as a triplet of DataFrames (config, score, context). - Default is N=1. The columns are ordered in ASCENDING order of the optimization targets. - The function uses `pandas.DataFrame.nsmallest(..., keep="first")` method under the hood. + Get the N best observations so far as a triplet of DataFrames (config, score, + context). Default is N=1. The columns are ordered in ASCENDING order of the + optimization targets. The function uses `pandas.DataFrame.nsmallest(..., + keep="first")` method under the hood. Parameters ---------- @@ -239,14 +242,15 @@ def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.Dat def cleanup(self) -> None: """ - Remove temp files, release resources, etc. after use. Default is no-op. - Redefine this method in optimizers that require cleanup. + Remove temp files, release resources, etc. + + after use. Default is no-op. Redefine this method in optimizers that require + cleanup. """ def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame: - """ - Convert numpy array from one-hot encoding to a DataFrame - with categoricals and ints in proper columns. + """Convert numpy array from one-hot encoding to a DataFrame with categoricals + and ints in proper columns. """ df_dict = collections.defaultdict(list) for i in range(config.shape[0]): @@ -267,9 +271,7 @@ def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame: return pd.DataFrame(df_dict) def _to_1hot(self, *, config: Union[pd.DataFrame, pd.Series]) -> npt.NDArray: - """ - Convert pandas DataFrame to one-hot-encoded numpy array. - """ + """Convert pandas DataFrame to one-hot-encoded numpy array.""" n_cols = 0 n_rows = config.shape[0] if config.ndim > 1 else 1 for param in self.optimizer_parameter_space.values(): diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index 0af785ef20..7f83b8e086 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Contains the RandomOptimizer class. -""" +"""Contains the RandomOptimizer class.""" from typing import Optional, Tuple from warnings import warn @@ -15,8 +13,9 @@ class RandomOptimizer(BaseOptimizer): - """Optimizer class that produces random suggestions. - Useful for baseline comparison against Bayesian optimizers. + """ + Optimizer class that produces random suggestions. Useful for baseline comparison + against Bayesian optimizers. Parameters ---------- @@ -26,7 +25,8 @@ class RandomOptimizer(BaseOptimizer): def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: - """Registers the given configs and scores. + """ + Registers the given configs and scores. Doesn't do anything on the RandomOptimizer except storing configs for logging. @@ -51,7 +51,8 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, # should we pop them from self.pending_observations? def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: - """Suggests a new configuration. + """ + Suggests a new configuration. 
Sampled at random using ConfigSpace. diff --git a/mlos_core/mlos_core/spaces/__init__.py b/mlos_core/mlos_core/spaces/__init__.py index d2a636ff1a..8de6887783 100644 --- a/mlos_core/mlos_core/spaces/__init__.py +++ b/mlos_core/mlos_core/spaces/__init__.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Space adapters and converters init file. -""" +"""Space adapters and converters init file.""" diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 2e2f585590..8618707f9a 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Basic initializer module for the mlos_core space adapters. -""" +"""Basic initializer module for the mlos_core space adapters.""" from enum import Enum from typing import Optional, TypeVar @@ -24,10 +22,10 @@ class SpaceAdapterType(Enum): """Enumerate supported MlosCore space adapters.""" IDENTITY = IdentityAdapter - """A no-op adapter will be used""" + """A no-op adapter will be used.""" LLAMATUNE = LlamaTuneAdapter - """An instance of LlamaTuneAdapter class will be used""" + """An instance of LlamaTuneAdapter class will be used.""" # To make mypy happy, we need to define a type variable for each optimizer type. @@ -42,7 +40,7 @@ class SpaceAdapterType(Enum): class SpaceAdapterFactory: - """Simple factory class for creating BaseSpaceAdapter-derived objects""" + """Simple factory class for creating BaseSpaceAdapter-derived objects.""" # pylint: disable=too-few-public-methods diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index 6c3a86fc8a..f28ab694a4 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Contains the BaseSpaceAdapter abstract class. -""" +"""Contains the BaseSpaceAdapter abstract class.""" from abc import ABCMeta, abstractmethod @@ -13,7 +11,8 @@ class BaseSpaceAdapter(metaclass=ABCMeta): - """SpaceAdapter abstract class defining the basic interface. + """ + SpaceAdapter abstract class defining the basic interface. Parameters ---------- @@ -35,23 +34,21 @@ def __repr__(self) -> str: @property def orig_parameter_space(self) -> ConfigSpace.ConfigurationSpace: - """ - Original (user-provided) parameter space to explore. - """ + """Original (user-provided) parameter space to explore.""" return self._orig_parameter_space @property @abstractmethod def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: - """ - Target parameter space that is fed to the underlying optimizer. - """ + """Target parameter space that is fed to the underlying optimizer.""" pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: - """Translates a configuration, which belongs to the target parameter space, to the original parameter space. - This method is called by the `suggest` method of the `BaseOptimizer` class. + """ + Translates a configuration, which belongs to the target parameter space, to the + original parameter space. This method is called by the `suggest` method of the + `BaseOptimizer` class. 
Parameters ---------- @@ -68,9 +65,11 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: - """Translates a configuration, which belongs to the original parameter space, to the target parameter space. - This method is called by the `register` method of the `BaseOptimizer` class, and performs the inverse operation - of `BaseSpaceAdapter.transform` method. + """ + Translates a configuration, which belongs to the original parameter space, to + the target parameter space. This method is called by the `register` method of + the `BaseOptimizer` class, and performs the inverse operation of + `BaseSpaceAdapter.transform` method. Parameters ---------- diff --git a/mlos_core/mlos_core/spaces/adapters/identity_adapter.py b/mlos_core/mlos_core/spaces/adapters/identity_adapter.py index ad79fa21c9..1e552110a2 100644 --- a/mlos_core/mlos_core/spaces/adapters/identity_adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/identity_adapter.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Contains the Identity (no-op) Space Adapter class. -""" +"""Contains the Identity (no-op) Space Adapter class.""" import ConfigSpace import pandas as pd @@ -13,7 +11,8 @@ class IdentityAdapter(BaseSpaceAdapter): - """Identity (no-op) SpaceAdapter class. + """ + Identity (no-op) SpaceAdapter class. Parameters ---------- diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 4d3a925cbc..8a416d40ab 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Implementation of LlamaTune space adapter. -""" +"""Implementation of LlamaTune space adapter.""" from typing import Dict, Optional from warnings import warn @@ -20,19 +18,22 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes - """ - Implementation of LlamaTune, a set of parameter space transformation techniques, + """Implementation of LlamaTune, a set of parameter space transformation techniques, aimed at improving the sample-efficiency of the underlying optimizer. """ DEFAULT_NUM_LOW_DIMS = 16 - """Default number of dimensions in the low-dimensional search space, generated by HeSBO projection""" + """Default number of dimensions in the low-dimensional search space, generated by + HeSBO projection. + """ DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = .2 - """Default percentage of bias for each special parameter value""" + """Default percentage of bias for each special parameter value.""" DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 - """Default number of (max) unique values of each parameter, when space discretization is used""" + """Default number of (max) unique values of each parameter, when space + discretization is used. + """ def __init__(self, *, orig_parameter_space: ConfigSpace.ConfigurationSpace, @@ -141,7 +142,8 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: return pd.DataFrame([list(orig_configuration.values())], columns=list(orig_configuration.keys())) def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_param: Optional[int]) -> None: - """Constructs the low-dimensional parameter (potentially discretized) search space. + """ + Constructs the low-dimensional parameter (potentially discretized) search space. 
Parameters ---------- @@ -183,8 +185,10 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para self._target_config_space = config_space def _transform(self, configuration: dict) -> dict: - """Projects a low-dimensional point (configuration) to the high-dimensional original parameter space, - and then biases the resulting parameter values towards their special value(s) (if any). + """ + Projects a low-dimensional point (configuration) to the high-dimensional + original parameter space, and then biases the resulting parameter values towards + their special value(s) (if any). Parameters ---------- @@ -237,7 +241,9 @@ def _transform(self, configuration: dict) -> dict: return original_config def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float) -> float: - """Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. + """ + Biases the special value(s) of this parameter, by shifting the normalized + `input_value` towards those. Parameters ---------- @@ -270,8 +276,9 @@ def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperpara # pylint: disable=too-complex,too-many-branches def _validate_special_param_values(self, special_param_values_dict: dict) -> None: - """Checks that the user-provided dict of special parameter values is valid. - And assigns it to the corresponding attribute. + """ + Checks that the user-provided dict of special parameter values is valid. And + assigns it to the corresponding attribute. Parameters ---------- diff --git a/mlos_core/mlos_core/spaces/converters/__init__.py b/mlos_core/mlos_core/spaces/converters/__init__.py index 8385a4938d..2360bda24f 100644 --- a/mlos_core/mlos_core/spaces/converters/__init__.py +++ b/mlos_core/mlos_core/spaces/converters/__init__.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Space converters init file. -""" +"""Space converters init file.""" diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index d6918f9891..4aee0154b6 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Contains space converters for FLAML. -""" +"""Contains space converters for FLAML.""" import sys from typing import TYPE_CHECKING, Dict @@ -28,7 +26,8 @@ def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> Dict[str, FlamlDomain]: - """Converts a ConfigSpace.ConfigurationSpace to dict. + """ + Converts a ConfigSpace.ConfigurationSpace to dict. Parameters ---------- diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index a8ad146205..6a0962f415 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Common functions for mlos_core Optimizer tests. -""" +"""Common functions for mlos_core Optimizer tests.""" import sys from importlib import import_module @@ -27,6 +25,7 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]: """ Imports all submodules for a package and returns their names. + Useful for dynamically enumerating subclasses. 
""" submodules = [] @@ -38,6 +37,7 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]: def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]: """ Gets the set of all of the subclasses of the given class. + Useful for dynamically enumerating expected test cases. """ return set(cls.__subclasses__()).union( @@ -46,8 +46,8 @@ def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]: def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]: """ - Gets a sorted list of all of the concrete subclasses of the given class. - Useful for dynamically enumerating expected test cases. + Gets a sorted list of all of the concrete subclasses of the given class. Useful for + dynamically enumerating expected test cases. Note: For abstract types, mypy will complain at the call site. Use "# type: ignore[type-abstract]" to suppress the warning. diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index c7a94dfcc4..e0b094e4d6 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for Bayesian Optimizers. -""" +"""Tests for Bayesian Optimizers.""" from typing import Optional, Type @@ -23,8 +21,8 @@ def test_context_not_implemented_warning(configuration_space: CS.ConfigurationSpace, optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: - """ - Make sure we raise warnings for the functionality that has not been implemented yet. + """Make sure we raise warnings for the functionality that has not been implemented + yet. """ if kwargs is None: kwargs = {} diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py index 39231bec5c..417b917552 100644 --- a/mlos_core/mlos_core/tests/optimizers/conftest.py +++ b/mlos_core/mlos_core/tests/optimizers/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Test fixtures for mlos_bench optimizers. -""" +"""Test fixtures for mlos_bench optimizers.""" import ConfigSpace as CS import pytest @@ -12,9 +10,7 @@ @pytest.fixture def configuration_space() -> CS.ConfigurationSpace: - """ - Test fixture to produce a config space with all types of hyperparameters. - """ + """Test fixture to produce a config space with all types of hyperparameters.""" # Start defining a ConfigurationSpace for the Optimizer to search. space = CS.ConfigurationSpace(seed=1234) # Add a continuous input dimension between 0 and 1. diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index 725d92fbe9..f9fe07fbf0 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for one-hot encoding for certain optimizers. -""" +"""Tests for one-hot encoding for certain optimizers.""" import ConfigSpace as CS import numpy as np @@ -21,6 +19,7 @@ def data_frame() -> pd.DataFrame: """ Toy data frame corresponding to the `configuration_space` hyperparameters. + The columns are deliberately *not* in alphabetic order. 
""" return pd.DataFrame({ @@ -34,6 +33,7 @@ def data_frame() -> pd.DataFrame: def one_hot_data_frame() -> npt.NDArray: """ One-hot encoding of the `data_frame` above. + The columns follow the order of the hyperparameters in `configuration_space`. """ return np.array([ @@ -47,6 +47,7 @@ def one_hot_data_frame() -> npt.NDArray: def series() -> pd.Series: """ Toy series corresponding to the `configuration_space` hyperparameters. + The columns are deliberately *not* in alphabetic order. """ return pd.Series({ @@ -60,6 +61,7 @@ def series() -> pd.Series: def one_hot_series() -> npt.NDArray: """ One-hot encoding of the `series` above. + The columns follow the order of the hyperparameters in `configuration_space`. """ return np.array([ @@ -70,7 +72,9 @@ def one_hot_series() -> npt.NDArray: @pytest.fixture def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: """ - Test fixture for the optimizer. Use it to test one-hot encoding/decoding. + Test fixture for the optimizer. + + Use it to test one-hot encoding/decoding. """ return SmacOptimizer( parameter_space=configuration_space, @@ -81,44 +85,34 @@ def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: def test_to_1hot_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray) -> None: - """ - Toy problem to test one-hot encoding of dataframe. - """ + """Toy problem to test one-hot encoding of dataframe.""" assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame) def test_to_1hot_series(optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray) -> None: - """ - Toy problem to test one-hot encoding of series. - """ + """Toy problem to test one-hot encoding of series.""" assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series) def test_from_1hot_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray) -> None: - """ - Toy problem to test one-hot decoding of dataframe. - """ + """Toy problem to test one-hot decoding of dataframe.""" assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict() def test_from_1hot_series(optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray) -> None: - """ - Toy problem to test one-hot decoding of series. - """ + """Toy problem to test one-hot decoding of series.""" one_hot_df = optimizer._from_1hot(config=one_hot_series) assert one_hot_df.shape[0] == 1, f"Unexpected number of rows ({one_hot_df.shape[0]} != 1)" assert one_hot_df.iloc[0].to_dict() == series.to_dict() def test_round_trip_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFrame) -> None: - """ - Round-trip test for one-hot-encoding and then decoding a data frame. - """ + """Round-trip test for one-hot-encoding and then decoding a data frame.""" df_round_trip = optimizer._from_1hot(config=optimizer._to_1hot(config=data_frame)) assert df_round_trip.x.to_numpy() == pytest.approx(data_frame.x) assert (df_round_trip.y == data_frame.y).all() @@ -126,9 +120,7 @@ def test_round_trip_data_frame(optimizer: BaseOptimizer, data_frame: pd.DataFram def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> None: - """ - Round-trip test for one-hot-encoding and then decoding a series. 
- """ + """Round-trip test for one-hot-encoding and then decoding a series.""" series_round_trip = optimizer._from_1hot(config=optimizer._to_1hot(config=series)) assert series_round_trip.x.to_numpy() == pytest.approx(series.x) assert (series_round_trip.y == series.y).all() @@ -137,17 +129,13 @@ def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> No def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer, one_hot_data_frame: npt.NDArray) -> None: - """ - Round-trip test for one-hot-decoding and then encoding of a numpy array. - """ + """Round-trip test for one-hot-decoding and then encoding of a numpy array.""" round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_data_frame)) assert round_trip == pytest.approx(one_hot_data_frame) def test_round_trip_reverse_series(optimizer: BaseOptimizer, one_hot_series: npt.NDArray) -> None: - """ - Round-trip test for one-hot-decoding and then encoding of a numpy array. - """ + """Round-trip test for one-hot-decoding and then encoding of a numpy array.""" round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_series)) assert round_trip == pytest.approx(one_hot_series) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index 0b9d624a7a..271bfce1d8 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Test multi-target optimization. -""" +"""Test multi-target optimization.""" import logging from typing import List, Optional, Type @@ -24,8 +22,7 @@ *[(member.value, {}) for member in OptimizerType], ]) def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None: - """ - Make sure that the optimizer raises an error if the number of objective weights + """Make sure that the optimizer raises an error if the number of objective weights does not match the number of optimization targets. """ with pytest.raises(ValueError): @@ -48,9 +45,8 @@ def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kw def test_multi_target_opt(objective_weights: Optional[List[float]], optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None: - """ - Toy multi-target optimization problem to test the optimizers with - mixed numeric types to ensure that original dtypes are retained. + """Toy multi-target optimization problem to test the optimizers with mixed numeric + types to ensure that original dtypes are retained. """ max_iterations = 10 diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index 5fd28ca1ed..b1c68ad136 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for Bayesian Optimizers. -""" +"""Tests for Bayesian Optimizers.""" import logging from copy import deepcopy @@ -37,9 +35,7 @@ ]) def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace, optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: - """ - Test that we can create an optimizer and get a suggestion from it. 
- """ + """Test that we can create an optimizer and get a suggestion from it.""" if kwargs is None: kwargs = {} optimizer = optimizer_class( @@ -67,9 +63,7 @@ def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace ]) def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: - """ - Toy problem to test the optimizers. - """ + """Toy problem to test the optimizers.""" # pylint: disable=too-many-locals max_iterations = 20 if kwargs is None: @@ -143,9 +137,7 @@ def objective(x: pd.Series) -> pd.DataFrame: *list(OptimizerType), ]) def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: - """ - Test that all optimizer types are listed in the ConcreteOptimizer constraints. - """ + """Test that all optimizer types are listed in the ConcreteOptimizer constraints.""" assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member @@ -158,9 +150,7 @@ def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: ]) def test_create_optimizer_with_factory_method(configuration_space: CS.ConfigurationSpace, optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: - """ - Test that we can create an optimizer via a factory. - """ + """Test that we can create an optimizer via a factory.""" if kwargs is None: kwargs = {} if optimizer_type is None: @@ -199,9 +189,7 @@ def test_create_optimizer_with_factory_method(configuration_space: CS.Configurat }), ]) def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optional[dict]) -> None: - """ - Toy problem to test the optimizers with llamatune space adapter. - """ + """Toy problem to test the optimizers with llamatune space adapter.""" # pylint: disable=too-complex,disable=too-many-statements,disable=too-many-locals num_iters = 50 if kwargs is None: @@ -327,9 +315,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: @pytest.mark.parametrize(('optimizer_class'), optimizer_subclasses) def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: - """ - Test that all optimizer classes are listed in the OptimizerType enum. - """ + """Test that all optimizer classes are listed in the OptimizerType enum.""" optimizer_type_classes = {member.value for member in OptimizerType} assert optimizer_class in optimizer_type_classes @@ -342,8 +328,8 @@ def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: # Optimizer with non-empty kwargs argument ]) def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: - """ - Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. + """Toy problem to test the optimizers with mixed numeric types to ensure that + original dtypes are retained. """ max_iterations = 10 if kwargs is None: diff --git a/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py index e7f3fb9d3e..b3b79ffadb 100644 --- a/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/random_optimizer_test.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for random optimizer. 
-""" +"""Tests for random optimizer.""" diff --git a/mlos_core/mlos_core/tests/spaces/__init__.py b/mlos_core/mlos_core/tests/spaces/__init__.py index 489802cb5a..a4112b6081 100644 --- a/mlos_core/mlos_core/tests/spaces/__init__.py +++ b/mlos_core/mlos_core/tests/spaces/__init__.py @@ -2,6 +2,4 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Basic initializer module for the mlos_core.tests.spaces package. -""" +"""Basic initializer module for the mlos_core.tests.spaces package.""" diff --git a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py index 37b8aa3a69..07f23507d9 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for Identity space adapter. -""" +"""Tests for Identity space adapter.""" # pylint: disable=missing-function-docstring @@ -15,9 +13,7 @@ def test_identity_adapter() -> None: - """ - Tests identity adapter - """ + """Tests identity adapter.""" input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index 84dcd4e5c0..5bddbaf807 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for LlamaTune space adapter. -""" +"""Tests for LlamaTune space adapter.""" # pylint: disable=missing-function-docstring @@ -23,9 +21,7 @@ def construct_parameter_space( n_categorical_params: int = 0, seed: int = 1234, ) -> CS.ConfigurationSpace: - """ - Helper function for construct an instance of `ConfigSpace.ConfigurationSpace`. - """ + """Helper function for construct an instance of `ConfigSpace.ConfigurationSpace`.""" input_space = CS.ConfigurationSpace(seed=seed) for idx in range(n_continuous_params): @@ -58,9 +54,7 @@ def construct_parameter_space( ) ])) def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals - """ - Tests LlamaTune's low-to-high space projection method. - """ + """Tests LlamaTune's low-to-high space projection method.""" input_space = construct_parameter_space(**param_space_kwargs) # Number of target parameter space dimensions should be fewer than those of the original space @@ -107,8 +101,8 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N def test_special_parameter_values_validation() -> None: - """ - Tests LlamaTune's validation process of user-provided special parameter values dictionary. + """Tests LlamaTune's validation process of user-provided special parameter values + dictionary. 
""" input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( @@ -199,9 +193,7 @@ def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex - """ - Tests LlamaTune's special parameter values biasing methodology - """ + """Tests LlamaTune's special parameter values biasing methodology.""" input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) @@ -290,9 +282,7 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co def test_max_unique_values_per_param() -> None: - """ - Tests LlamaTune's parameter values discretization implementation. - """ + """Tests LlamaTune's parameter values discretization implementation.""" # Define config space with a mix of different parameter types input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( @@ -346,8 +336,8 @@ def test_max_unique_values_per_param() -> None: ) ])) def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals - """ - Tests LlamaTune's approximate high-to-low space projection method, using pseudo-inverse. + """Tests LlamaTune's approximate high-to-low space projection method, using pseudo- + inverse. """ input_space = construct_parameter_space(**param_space_kwargs) @@ -403,9 +393,7 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: for max_unique_values_per_param in (50, 250) ])) def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int) -> None: - """ - Tests LlamaTune space adapter when all components are active. - """ + """Tests LlamaTune space adapter when all components are active.""" # pylint: disable=too-many-locals # Define config space with a mix of different parameter types @@ -475,8 +463,8 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u ) ])) def test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_space_kwargs: dict) -> None: - """ - Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. + """Tests LlamaTune's space adapter deterministic behavior when given same seed in + the input parameter space. """ def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: input_space = construct_parameter_space(**param_space_kwargs, seed=seed) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index 5390f97c5f..fd22d0c257 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for space adapter factory. -""" +"""Tests for space adapter factory.""" # pylint: disable=missing-function-docstring @@ -29,9 +27,7 @@ *list(SpaceAdapterType), ]) def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: - """ - Test that all optimizer types are listed in the ConcreteOptimizer constraints. 
- """ + """Test that all optimizer types are listed in the ConcreteOptimizer constraints.""" # pylint: disable=no-member assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] @@ -86,8 +82,6 @@ def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[S @pytest.mark.parametrize(('space_adapter_class'), space_adapter_subclasses) def test_space_adapter_type_defs(space_adapter_class: Type[BaseSpaceAdapter]) -> None: - """ - Test that all space adapter classes are listed in the SpaceAdapterType enum. - """ + """Test that all space adapter classes are listed in the SpaceAdapterType enum.""" space_adapter_type_classes = {space_adapter_type.value for space_adapter_type in SpaceAdapterType} assert space_adapter_class in space_adapter_type_classes diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index dee9251652..35a8f9ebb3 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Tests for mlos_core.spaces -""" +"""Tests for mlos_core.spaces.""" # pylint: disable=missing-function-docstring @@ -67,16 +65,12 @@ def test_is_log_uniform() -> None: def invalid_conversion_function(*args: Any) -> NoReturn: - """ - A quick dummy function for the base class to make pylint happy. - """ + """A quick dummy function for the base class to make pylint happy.""" raise NotImplementedError('subclass must override conversion_function') class BaseConversion(metaclass=ABCMeta): - """ - Base class for testing optimizer space conversions. - """ + """Base class for testing optimizer space conversions.""" conversion_function: Callable[..., OptimizerSpace] = invalid_conversion_function @abstractmethod @@ -116,9 +110,7 @@ def categorical_counts(self, points: npt.NDArray) -> npt.NDArray: @abstractmethod def test_dimensionality(self) -> None: - """ - Check that the dimensionality of the converted space is correct. - """ + """Check that the dimensionality of the converted space is correct.""" def test_unsupported_hyperparameter(self) -> None: input_space = CS.ConfigurationSpace() @@ -175,9 +167,7 @@ def test_log_float_spaces(self) -> None: class TestFlamlConversion(BaseConversion): - """ - Tests for ConfigSpace to Flaml parameter conversions. - """ + """Tests for ConfigSpace to Flaml parameter conversions.""" conversion_function = staticmethod(configspace_to_flaml_space) diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index df0e144535..f6933cbb6a 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Internal helper functions for mlos_core package. -""" +"""Internal helper functions for mlos_core package.""" from typing import Union @@ -13,7 +11,8 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame: - """Converts a ConfigSpace config to a DataFrame + """ + Converts a ConfigSpace config to a DataFrame. Parameters ---------- diff --git a/mlos_core/mlos_core/version.py b/mlos_core/mlos_core/version.py index 2362de7083..61eb665064 100644 --- a/mlos_core/mlos_core/version.py +++ b/mlos_core/mlos_core/version.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Version number for the mlos_core package. 
-""" +"""Version number for the mlos_core package.""" # NOTE: This should be managed by bumpversion. VERSION = '0.5.1' diff --git a/mlos_core/setup.py b/mlos_core/setup.py index fed376d1af..3771d73f43 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Setup instructions for the mlos_core package. -""" +"""Setup instructions for the mlos_core package.""" # pylint: disable=duplicate-code diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py index 2390554e1e..a2a36b54a9 100644 --- a/mlos_viz/mlos_viz/__init__.py +++ b/mlos_viz/mlos_viz/__init__.py @@ -2,8 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -mlos_viz is a framework to help visualizing, explain, and gain insights from results +"""mlos_viz is a framework to help visualizing, explain, and gain insights from results from the mlos_bench framework for benchmarking and optimization automation. """ @@ -18,9 +17,7 @@ class MlosVizMethod(Enum): - """ - What method to use for visualizing the experiment results. - """ + """What method to use for visualizing the experiment results.""" DABL = "dabl" AUTO = DABL # use dabl as the current default diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 15358b0862..10b9946051 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Base functions for visualizing, explain, and gain insights from results. -""" +"""Base functions for visualizing, explain, and gain insights from results.""" import re import warnings @@ -37,8 +35,7 @@ def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: def ignore_plotter_warnings() -> None: - """ - Suppress some annoying warnings from third-party data visualization packages by + """Suppress some annoying warnings from third-party data visualization packages by adding them to the warnings filter. """ warnings.filterwarnings("ignore", category=FutureWarning) @@ -192,8 +189,8 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, ) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: # pylint: disable=too-many-locals """ - Utility function to process the results and determine the best performing - configs including potential repeats to help assess variability. + Utility function to process the results and determine the best performing configs + including potential repeats to help assess variability. Parameters ---------- diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 504486a58c..40deb848fd 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Small wrapper functions for dabl plotting functions via mlos_bench data. -""" +"""Small wrapper functions for dabl plotting functions via mlos_bench data.""" import warnings from typing import Dict, Literal, Optional @@ -39,9 +37,7 @@ def plot(exp_data: Optional[ExperimentData] = None, *, def ignore_plotter_warnings() -> None: - """ - Add some filters to ignore warnings from the plotter. 
- """ + """Add some filters to ignore warnings from the plotter.""" # pylint: disable=import-outside-toplevel warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Could not infer format") diff --git a/mlos_viz/mlos_viz/tests/__init__.py b/mlos_viz/mlos_viz/tests/__init__.py index 2aa5f430cf..df64e0a313 100644 --- a/mlos_viz/mlos_viz/tests/__init__.py +++ b/mlos_viz/mlos_viz/tests/__init__.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mlos_viz. -""" +"""Unit tests for mlos_viz.""" import sys diff --git a/mlos_viz/mlos_viz/tests/conftest.py b/mlos_viz/mlos_viz/tests/conftest.py index ad29489e2c..228609ba09 100644 --- a/mlos_viz/mlos_viz/tests/conftest.py +++ b/mlos_viz/mlos_viz/tests/conftest.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Export test fixtures for mlos_viz. -""" +"""Export test fixtures for mlos_viz.""" from mlos_bench.tests import tunable_groups_fixtures from mlos_bench.tests.storage.sql import fixtures as sql_storage_fixtures diff --git a/mlos_viz/mlos_viz/tests/test_base_plot.py b/mlos_viz/mlos_viz/tests/test_base_plot.py index 52d571e742..1dc283c891 100644 --- a/mlos_viz/mlos_viz/tests/test_base_plot.py +++ b/mlos_viz/mlos_viz/tests/test_base_plot.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mlos_viz. -""" +"""Unit tests for mlos_viz.""" import warnings from unittest.mock import Mock, patch diff --git a/mlos_viz/mlos_viz/tests/test_dabl_plot.py b/mlos_viz/mlos_viz/tests/test_dabl_plot.py index fc4dd3667a..7fcee4dfe9 100644 --- a/mlos_viz/mlos_viz/tests/test_dabl_plot.py +++ b/mlos_viz/mlos_viz/tests/test_dabl_plot.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mlos_viz.dabl.plot. -""" +"""Unit tests for mlos_viz.dabl.plot.""" import warnings from unittest.mock import Mock, patch diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py index 06ac4a7664..ecd072c287 100644 --- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py +++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Unit tests for mlos_viz. -""" +"""Unit tests for mlos_viz.""" import random import warnings diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index 744fe28648..deb5227bc3 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Utility functions for manipulating experiment results data. -""" +"""Utility functions for manipulating experiment results data.""" from typing import Dict, Literal, Optional, Tuple import pandas diff --git a/mlos_viz/mlos_viz/version.py b/mlos_viz/mlos_viz/version.py index 607c7cc014..1d10835cd0 100644 --- a/mlos_viz/mlos_viz/version.py +++ b/mlos_viz/mlos_viz/version.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Version number for the mlos_viz package. -""" +"""Version number for the mlos_viz package.""" # NOTE: This should be managed by bumpversion. 
VERSION = '0.5.1' diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index 98d12598e1..73fd0f3c66 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -2,9 +2,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. # -""" -Setup instructions for the mlos_viz package. -""" +"""Setup instructions for the mlos_viz package.""" # pylint: disable=duplicate-code From ec1ba4d3e647afc5337c0759abce936aa9ababdf Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 18:08:11 +0000 Subject: [PATCH 29/54] restore shorter line lengths change --- .editorconfig | 3 +++ .pylintrc | 2 +- pyproject.toml | 2 +- setup.cfg | 4 ++-- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.editorconfig b/.editorconfig index e984d47595..7e753174de 100644 --- a/.editorconfig +++ b/.editorconfig @@ -12,6 +12,9 @@ charset = utf-8 # Note: this is not currently supported by all editors or their editorconfig plugins. max_line_length = 132 +[*.py] +max_line_length = 99 + # Makefiles need tab indentation [{Makefile,*.mk}] indent_style = tab diff --git a/.pylintrc b/.pylintrc index e686070503..c6c512ecb7 100644 --- a/.pylintrc +++ b/.pylintrc @@ -35,7 +35,7 @@ load-plugins= [FORMAT] # Maximum number of characters on a single line. -max-line-length=132 +max-line-length=99 [MESSAGE CONTROL] disable= diff --git a/pyproject.toml b/pyproject.toml index 16484d0aba..6865bf3d71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.black] -line-length = 132 +line-length = 99 target-version = ["py38", "py39", "py310", "py311", "py312"] include = '\.pyi?$' diff --git a/setup.cfg b/setup.cfg index 88fd64a8e2..6f948f523a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ count = True ignore = E203,W503,W504 format = pylint # See Also: .editorconfig, .pylintrc -max-line-length = 132 +max-line-length = 99 show-source = True statistics = True @@ -25,7 +25,7 @@ match = .+(? 
Date: Mon, 8 Jul 2024 19:37:11 +0000 Subject: [PATCH 30/54] reformat with black and lots of manual fixups for errors or missing things (e.g., comments, docstrings) --- conftest.py | 21 +- doc/source/conf.py | 9 +- .../fio/scripts/local/process_fio_results.py | 28 +- .../scripts/local/generate_redis_config.py | 12 +- .../scripts/local/process_redis_results.py | 19 +- .../boot/scripts/local/create_new_grub_cfg.py | 10 +- .../scripts/local/generate_grub_config.py | 10 +- .../local/generate_kernel_config_script.py | 10 +- .../mlos_bench/config/schemas/__init__.py | 4 +- .../config/schemas/config_schemas.py | 21 +- mlos_bench/mlos_bench/dict_templater.py | 14 +- .../mlos_bench/environments/__init__.py | 14 +- .../environments/base_environment.py | 110 +++-- .../mlos_bench/environments/composite_env.py | 55 ++- .../mlos_bench/environments/local/__init__.py | 4 +- .../environments/local/local_env.py | 111 +++-- .../environments/local/local_fileshare_env.py | 67 +-- .../mlos_bench/environments/mock_env.py | 39 +- .../environments/remote/__init__.py | 12 +- .../environments/remote/host_env.py | 29 +- .../environments/remote/network_env.py | 33 +- .../mlos_bench/environments/remote/os_env.py | 36 +- .../environments/remote/remote_env.py | 38 +- .../environments/remote/saas_env.py | 42 +- .../mlos_bench/environments/script_env.py | 40 +- mlos_bench/mlos_bench/event_loop_context.py | 2 +- mlos_bench/mlos_bench/launcher.py | 272 ++++++++---- mlos_bench/mlos_bench/optimizers/__init__.py | 8 +- .../mlos_bench/optimizers/base_optimizer.py | 81 ++-- .../optimizers/convert_configspace.py | 114 ++--- .../optimizers/grid_search_optimizer.py | 79 ++-- .../optimizers/mlos_core_optimizer.py | 96 +++-- .../mlos_bench/optimizers/mock_optimizer.py | 26 +- .../optimizers/one_shot_optimizer.py | 12 +- .../optimizers/track_best_optimizer.py | 26 +- mlos_bench/mlos_bench/os_environ.py | 11 +- mlos_bench/mlos_bench/run.py | 5 +- mlos_bench/mlos_bench/schedulers/__init__.py | 4 +- .../mlos_bench/schedulers/base_scheduler.py | 94 +++-- .../mlos_bench/schedulers/sync_scheduler.py | 4 +- mlos_bench/mlos_bench/services/__init__.py | 6 +- .../mlos_bench/services/base_fileshare.py | 43 +- .../mlos_bench/services/base_service.py | 56 ++- .../mlos_bench/services/config_persistence.py | 285 ++++++++----- .../mlos_bench/services/local/__init__.py | 2 +- .../mlos_bench/services/local/local_exec.py | 45 +- .../services/local/temp_dir_context.py | 19 +- .../services/remote/azure/__init__.py | 10 +- .../services/remote/azure/azure_auth.py | 40 +- .../remote/azure/azure_deployment_services.py | 141 ++++--- .../services/remote/azure/azure_fileshare.py | 31 +- .../remote/azure/azure_network_services.py | 74 ++-- .../services/remote/azure/azure_saas.py | 111 ++--- .../remote/azure/azure_vm_services.py | 252 ++++++----- .../services/remote/ssh/ssh_fileshare.py | 44 +- .../services/remote/ssh/ssh_host_service.py | 95 +++-- .../services/remote/ssh/ssh_service.py | 138 +++--- .../mlos_bench/services/types/__init__.py | 16 +- .../services/types/config_loader_type.py | 43 +- .../services/types/fileshare_type.py | 8 +- .../services/types/host_provisioner_type.py | 3 +- .../services/types/local_exec_type.py | 13 +- .../types/network_provisioner_type.py | 7 +- .../services/types/remote_config_type.py | 3 +- .../services/types/remote_exec_type.py | 5 +- mlos_bench/mlos_bench/storage/__init__.py | 4 +- .../storage/base_experiment_data.py | 22 +- mlos_bench/mlos_bench/storage/base_storage.py | 115 +++-- .../mlos_bench/storage/base_trial_data.py | 22 
+- .../storage/base_tunable_config_data.py | 3 +- .../base_tunable_config_trial_group_data.py | 16 +- mlos_bench/mlos_bench/storage/sql/__init__.py | 2 +- mlos_bench/mlos_bench/storage/sql/common.py | 236 +++++++---- .../mlos_bench/storage/sql/experiment.py | 253 ++++++----- .../mlos_bench/storage/sql/experiment_data.py | 106 +++-- mlos_bench/mlos_bench/storage/sql/schema.py | 48 ++- mlos_bench/mlos_bench/storage/sql/storage.py | 26 +- mlos_bench/mlos_bench/storage/sql/trial.py | 114 ++--- .../mlos_bench/storage/sql/trial_data.py | 74 ++-- .../storage/sql/tunable_config_data.py | 14 +- .../sql/tunable_config_trial_group_data.py | 41 +- .../mlos_bench/storage/storage_factory.py | 8 +- mlos_bench/mlos_bench/storage/util.py | 18 +- mlos_bench/mlos_bench/tests/__init__.py | 37 +- .../mlos_bench/tests/config/__init__.py | 8 +- .../cli/test_load_cli_config_examples.py | 56 ++- .../mlos_bench/tests/config/conftest.py | 14 +- .../test_load_environment_config_examples.py | 64 ++- .../test_load_global_config_examples.py | 28 +- .../test_load_optimizer_config_examples.py | 8 +- .../tests/config/schemas/__init__.py | 56 ++- .../config/schemas/cli/test_cli_schemas.py | 15 +- .../environments/test_environment_schemas.py | 32 +- .../schemas/globals/test_globals_schemas.py | 6 +- .../optimizers/test_optimizer_schemas.py | 64 ++- .../schedulers/test_scheduler_schemas.py | 26 +- .../schemas/services/test_services_schemas.py | 35 +- .../schemas/storage/test_storage_schemas.py | 36 +- .../test_tunable_params_schemas.py | 1 + .../test_tunable_values_schemas.py | 6 +- .../test_load_service_config_examples.py | 14 +- .../test_load_storage_config_examples.py | 8 +- mlos_bench/mlos_bench/tests/conftest.py | 16 +- .../mlos_bench/tests/environments/__init__.py | 14 +- .../tests/environments/base_env_test.py | 10 +- .../composite_env_service_test.py | 22 +- .../tests/environments/composite_env_test.py | 143 +++---- .../environments/include_tunables_test.py | 40 +- .../tests/environments/local/__init__.py | 20 +- .../local/composite_local_env_test.py | 19 +- .../local/local_env_stdout_test.py | 88 ++-- .../local/local_env_telemetry_test.py | 148 ++++--- .../environments/local/local_env_test.py | 73 ++-- .../environments/local/local_env_vars_test.py | 57 +-- .../local/local_fileshare_env_test.py | 25 +- .../tests/environments/mock_env_test.py | 64 +-- .../tests/environments/remote/test_ssh_env.py | 18 +- .../tests/event_loop_context_test.py | 63 ++- .../tests/launcher_in_process_test.py | 40 +- .../tests/launcher_parse_args_test.py | 123 +++--- .../mlos_bench/tests/launcher_run_test.py | 93 +++-- .../mlos_bench/tests/optimizers/conftest.py | 48 +-- .../optimizers/grid_search_optimizer_test.py | 107 +++-- .../tests/optimizers/llamatune_opt_test.py | 5 +- .../tests/optimizers/mlos_core_opt_df_test.py | 68 +-- .../optimizers/mlos_core_opt_smac_test.py | 78 ++-- .../tests/optimizers/mock_opt_test.py | 67 +-- .../optimizers/opt_bulk_register_test.py | 101 +++-- .../optimizers/toy_optimization_loop_test.py | 16 +- .../mlos_bench/tests/services/__init__.py | 8 +- .../tests/services/config_persistence_test.py | 29 +- .../tests/services/local/__init__.py | 2 +- .../services/local/local_exec_python_test.py | 9 +- .../tests/services/local/local_exec_test.py | 120 +++--- .../tests/services/local/mock/__init__.py | 2 +- .../local/mock/mock_local_exec_service.py | 23 +- .../mlos_bench/tests/services/mock_service.py | 23 +- .../tests/services/remote/__init__.py | 6 +- .../remote/azure/azure_fileshare_test.py | 142 ++++--- 
.../azure/azure_network_services_test.py | 83 ++-- .../remote/azure/azure_vm_services_test.py | 210 ++++++---- .../tests/services/remote/azure/conftest.py | 99 +++-- .../services/remote/mock/mock_auth_service.py | 26 +- .../remote/mock/mock_fileshare_service.py | 25 +- .../remote/mock/mock_network_service.py | 35 +- .../remote/mock/mock_remote_exec_service.py | 26 +- .../services/remote/mock/mock_vm_service.py | 55 ++- .../tests/services/remote/ssh/__init__.py | 20 +- .../tests/services/remote/ssh/fixtures.py | 68 ++- .../services/remote/ssh/test_ssh_fileshare.py | 43 +- .../remote/ssh/test_ssh_host_service.py | 97 +++-- .../services/remote/ssh/test_ssh_service.py | 57 ++- .../mlos_bench/tests/storage/conftest.py | 4 +- .../mlos_bench/tests/storage/exp_data_test.py | 67 +-- .../mlos_bench/tests/storage/exp_load_test.py | 62 +-- .../mlos_bench/tests/storage/sql/fixtures.py | 86 ++-- .../tests/storage/trial_config_test.py | 10 +- .../tests/storage/trial_schedule_test.py | 22 +- .../tests/storage/trial_telemetry_test.py | 41 +- .../tests/storage/tunable_config_data_test.py | 21 +- .../tunable_config_trial_group_data_test.py | 38 +- .../mlos_bench/tests/test_with_alt_tz.py | 6 +- .../tests/tunable_groups_fixtures.py | 38 +- .../mlos_bench/tests/tunables/conftest.py | 47 ++- .../tunables/test_tunable_categoricals.py | 2 +- .../tunables/test_tunables_size_props.py | 26 +- .../tests/tunables/tunable_comparison_test.py | 15 +- .../tests/tunables/tunable_definition_test.py | 98 +++-- .../tunables/tunable_distributions_test.py | 68 ++- .../tunables/tunable_group_indexing_test.py | 4 +- .../tunables/tunable_group_subgroup_test.py | 2 +- .../tunable_to_configspace_distr_test.py | 54 +-- .../tunables/tunable_to_configspace_test.py | 59 ++- .../tests/tunables/tunables_assign_test.py | 26 +- .../tests/tunables/tunables_str_test.py | 76 ++-- mlos_bench/mlos_bench/tunables/__init__.py | 6 +- .../mlos_bench/tunables/covariant_group.py | 18 +- mlos_bench/mlos_bench/tunables/tunable.py | 57 ++- .../mlos_bench/tunables/tunable_groups.py | 58 ++- mlos_bench/mlos_bench/util.py | 42 +- mlos_bench/mlos_bench/version.py | 2 +- mlos_bench/setup.py | 81 ++-- mlos_core/mlos_core/optimizers/__init__.py | 32 +- .../bayesian_optimizers/__init__.py | 4 +- .../bayesian_optimizers/bayesian_optimizer.py | 20 +- .../bayesian_optimizers/smac_optimizer.py | 185 +++++--- .../mlos_core/optimizers/flaml_optimizer.py | 91 ++-- mlos_core/mlos_core/optimizers/optimizer.py | 128 ++++-- .../mlos_core/optimizers/random_optimizer.py | 33 +- .../mlos_core/spaces/adapters/__init__.py | 24 +- .../mlos_core/spaces/adapters/adapter.py | 15 +- .../mlos_core/spaces/adapters/llamatune.py | 219 +++++++--- .../mlos_core/spaces/converters/flaml.py | 18 +- mlos_core/mlos_core/tests/__init__.py | 19 +- .../optimizers/bayesian_optimizers_test.py | 23 +- .../mlos_core/tests/optimizers/conftest.py | 6 +- .../tests/optimizers/one_hot_test.py | 77 ++-- .../optimizers/optimizer_multiobj_test.py | 78 ++-- .../tests/optimizers/optimizer_test.py | 226 ++++++---- .../spaces/adapters/identity_adapter_test.py | 25 +- .../tests/spaces/adapters/llamatune_test.py | 394 +++++++++++------- .../adapters/space_adapter_factory_test.py | 64 +-- .../mlos_core/tests/spaces/spaces_test.py | 45 +- mlos_core/mlos_core/util.py | 10 +- mlos_core/mlos_core/version.py | 2 +- mlos_core/setup.py | 59 +-- mlos_viz/mlos_viz/__init__.py | 19 +- mlos_viz/mlos_viz/base.py | 223 ++++++---- mlos_viz/mlos_viz/dabl.py | 62 ++- mlos_viz/mlos_viz/tests/test_mlos_viz.py | 4 +- 
mlos_viz/mlos_viz/util.py | 10 +- mlos_viz/mlos_viz/version.py | 2 +- mlos_viz/setup.py | 47 ++- 213 files changed, 6612 insertions(+), 4210 deletions(-) diff --git a/conftest.py b/conftest.py index e22395f82f..7985ef8239 100644 --- a/conftest.py +++ b/conftest.py @@ -32,14 +32,23 @@ def pytest_configure(config: pytest.Config) -> None: Add some additional (global) configuration steps for pytest. """ # Workaround some issues loading emukit in certain environments. - if os.environ.get('DISPLAY', None): + if os.environ.get("DISPLAY", None): try: - import matplotlib # pylint: disable=import-outside-toplevel - matplotlib.rcParams['backend'] = 'agg' - if is_master(config) or dict(getattr(config, 'workerinput', {}))['workerid'] == 'gw0': + import matplotlib # pylint: disable=import-outside-toplevel + + matplotlib.rcParams["backend"] = "agg" + if is_master(config) or dict(getattr(config, "workerinput", {}))["workerid"] == "gw0": # Only warn once. - warn(UserWarning('DISPLAY environment variable is set, which can cause problems in some setups (e.g. WSL). ' - + f'Adjusting matplotlib backend to "{matplotlib.rcParams["backend"]}" to compensate.')) + warn( + UserWarning( + ( + "DISPLAY environment variable is set, " + "which can cause problems in some setups (e.g. WSL). " + f"Adjusting matplotlib backend to '{matplotlib.rcParams['backend']}' " + "to compensate." + ) + ) + ) except ImportError: pass diff --git a/doc/source/conf.py b/doc/source/conf.py index 3e25d9b082..a06436ba1f 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -87,10 +87,11 @@ numpydoc_class_members_toctree = False autodoc_default_options = { - 'members': True, - 'undoc-members': True, - # Don't generate documentation for some (non-private) functions that are more for internal implementation use. - 'exclude-members': 'mlos_bench.util.check_required_params' + "members": True, + "undoc-members": True, + # Don't generate documentation for some (non-private) functions that are more + # for internal implementation use. 
+ "exclude-members": "mlos_bench.util.check_required_params", } # Generate the plots for the gallery diff --git a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py index a6d2d31df6..43baeb1cf8 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/fio/scripts/local/process_fio_results.py @@ -16,7 +16,7 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: """Flatten every dict in the hierarchy and rename the keys with the dict path.""" if isinstance(data, dict): - for (key, val) in data.items(): + for key, val in data.items(): yield from _flat_dict(val, f"{path}.{key}") else: yield (path, data) @@ -24,13 +24,15 @@ def _flat_dict(data: Any, path: str) -> Iterator[Tuple[str, Any]]: def _main(input_file: str, output_file: str, prefix: str) -> None: """Convert FIO read data from JSON to tall CSV.""" - with open(input_file, mode='r', encoding='utf-8') as fh_input: + with open(input_file, mode="r", encoding="utf-8") as fh_input: json_data = json.load(fh_input) - data = list(itertools.chain( - _flat_dict(json_data["jobs"][0], prefix), - _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util") - )) + data = list( + itertools.chain( + _flat_dict(json_data["jobs"][0], prefix), + _flat_dict(json_data["disk_util"][0], f"{prefix}.disk_util"), + ) + ) tall_df = pandas.DataFrame(data, columns=["metric", "value"]) tall_df.to_csv(output_file, index=False) @@ -43,12 +45,18 @@ def _main(input_file: str, output_file: str, prefix: str) -> None: parser = argparse.ArgumentParser(description="Post-process FIO benchmark results.") parser.add_argument( - "input", help="FIO benchmark results in JSON format (downloaded from a remote VM).") + "input", + help="FIO benchmark results in JSON format (downloaded from a remote VM).", + ) parser.add_argument( - "output", help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).") + "output", + help="Converted FIO benchmark data (CSV, to be consumed by mlos_bench).", + ) parser.add_argument( - "--prefix", default="fio", - help="Prefix of the metric IDs (default 'fio')") + "--prefix", + default="fio", + help="Prefix of the metric IDs (default 'fio')", + ) args = parser.parse_args() _main(args.input, args.output, args.prefix) diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py index 949b9f9d91..d41f20d2a9 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/generate_redis_config.py @@ -14,17 +14,19 @@ def _main(fname_input: str, fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ - open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in json.load(fh_tunables).items(): - line = f'{key} {val}' + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( + fname_output, "wt", encoding="utf-8", newline="" + ) as fh_config: + for key, val in json.load(fh_tunables).items(): + line = f"{key} {val}" fh_config.write(line + "\n") print(line) if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate Redis config from tunable parameters JSON.") + 
description="generate Redis config from tunable parameters JSON." + ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output Redis config file.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py index 08cfe57faa..d7f35f3d17 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py @@ -17,18 +17,19 @@ def _main(input_file: str, output_file: str) -> None: # Format the results from wide to long # The target is columns of metric and value to act as key-value pairs. df_long = ( - df_wide - .melt(id_vars=["test"]) + df_wide.melt(id_vars=["test"]) .assign(metric=lambda df: df["test"] + "_" + df["variable"]) .drop(columns=["test", "variable"]) .loc[:, ["metric", "value"]] ) # Add a default `score` metric to the end of the dataframe. - df_long = pd.concat([ - df_long, - pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}) - ]) + df_long = pd.concat( + [ + df_long, + pd.DataFrame({"metric": ["score"], "value": [df_long.value[df_long.index.max()]]}), + ] + ) df_long.to_csv(output_file, index=False) print(f"Converted: {input_file} -> {output_file}") @@ -38,7 +39,9 @@ def _main(input_file: str, output_file: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.") parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).") - parser.add_argument("output", help="Converted Redis benchmark data" + - " (to be consumed by OS Autotune framework).") + parser.add_argument( + "output", + help="Converted Redis benchmark data (to be consumed by OS Autotune framework).", + ) args = parser.parse_args() _main(args.input, args.output) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py index 47ed159c5a..9b75f04008 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/create_new_grub_cfg.py @@ -14,8 +14,10 @@ JSON_CONFIG_FILE = "config-boot-time.json" NEW_CFG = "zz-mlos-boot-params.cfg" -with open(JSON_CONFIG_FILE, 'r', encoding='UTF-8') as fh_json, \ - open(NEW_CFG, 'w', encoding='UTF-8') as fh_config: +with open(JSON_CONFIG_FILE, "r", encoding="UTF-8") as fh_json, open( + NEW_CFG, "w", encoding="UTF-8" +) as fh_config: for key, val in json.load(fh_json).items(): - fh_config.write('GRUB_CMDLINE_LINUX_DEFAULT="$' - f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n') + fh_config.write( + 'GRUB_CMDLINE_LINUX_DEFAULT="$' f'{{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"\n' + ) diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py index de344d61fb..9f130e5c0e 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/boot/scripts/local/generate_grub_config.py @@ -14,9 +14,10 @@ def _main(fname_input: str, 
fname_output: str) -> None: - with open(fname_input, "rt", encoding="utf-8") as fh_tunables, \ - open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in json.load(fh_tunables).items(): + with open(fname_input, "rt", encoding="utf-8") as fh_tunables, open( + fname_output, "wt", encoding="utf-8", newline="" + ) as fh_config: + for key, val in json.load(fh_tunables).items(): line = f'GRUB_CMDLINE_LINUX_DEFAULT="${{GRUB_CMDLINE_LINUX_DEFAULT}} {key}={val}"' fh_config.write(line + "\n") print(line) @@ -24,7 +25,8 @@ def _main(fname_input: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Generate GRUB config from tunable parameters JSON.") + description="Generate GRUB config from tunable parameters JSON." + ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("output", help="Output shell script to configure GRUB.") args = parser.parse_args() diff --git a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py index 85a49a1817..a4e5e5ccb6 100755 --- a/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py +++ b/mlos_bench/mlos_bench/config/environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py @@ -6,7 +6,10 @@ """ Helper script to generate a script to update kernel parameters from tunables JSON. -Run: `./generate_kernel_config_script.py ./kernel-params.json ./kernel-params-meta.json ./config-kernel.sh` +Run: + ./generate_kernel_config_script.py \ + ./kernel-params.json ./kernel-params-meta.json \ + ./config-kernel.sh """ import argparse @@ -22,7 +25,7 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: tunables_meta = json.load(fh_meta) with open(fname_output, "wt", encoding="utf-8", newline="") as fh_config: - for (key, val) in tunables_data.items(): + for key, val in tunables_data.items(): meta = tunables_meta.get(key, {}) name_prefix = meta.get("name_prefix", "") line = f'echo "{val}" > {name_prefix}{key}' @@ -33,7 +36,8 @@ def _main(fname_input: str, fname_meta: str, fname_output: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="generate a script to update kernel parameters from tunables JSON.") + description="generate a script to update kernel parameters from tunables JSON." + ) parser.add_argument("input", help="JSON file with tunable parameters.") parser.add_argument("meta", help="JSON file with tunable parameters metadata.") diff --git a/mlos_bench/mlos_bench/config/schemas/__init__.py b/mlos_bench/mlos_bench/config/schemas/__init__.py index 05756f59bf..d4987add63 100644 --- a/mlos_bench/mlos_bench/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/config/schemas/__init__.py @@ -7,6 +7,6 @@ from mlos_bench.config.schemas.config_schemas import CONFIG_SCHEMA_DIR, ConfigSchema __all__ = [ - 'ConfigSchema', - 'CONFIG_SCHEMA_DIR', + "ConfigSchema", + "CONFIG_SCHEMA_DIR", ] diff --git a/mlos_bench/mlos_bench/config/schemas/config_schemas.py b/mlos_bench/mlos_bench/config/schemas/config_schemas.py index bfba5ed8a6..b7ce402b5d 100644 --- a/mlos_bench/mlos_bench/config/schemas/config_schemas.py +++ b/mlos_bench/mlos_bench/config/schemas/config_schemas.py @@ -27,9 +27,14 @@ # It is used in `ConfigSchema.validate()` method below. 
# NOTE: this may cause pytest to fail if it's expecting exceptions # to be raised for invalid configs. -_VALIDATION_ENV_FLAG = 'MLOS_BENCH_SKIP_SCHEMA_VALIDATION' -_SKIP_VALIDATION = (environ.get(_VALIDATION_ENV_FLAG, 'false').lower() - in {'true', 'y', 'yes', 'on', '1'}) +_VALIDATION_ENV_FLAG = "MLOS_BENCH_SKIP_SCHEMA_VALIDATION" +_SKIP_VALIDATION = environ.get(_VALIDATION_ENV_FLAG, "false").lower() in { + "true", + "y", + "yes", + "on", + "1", +} # Note: we separate out the SchemaStore from a class method on ConfigSchema @@ -84,10 +89,12 @@ def _load_registry(cls) -> None: """ if not cls._SCHEMA_STORE: cls._load_schemas() - cls._REGISTRY = Registry().with_resources([ - (url, Resource.from_contents(schema, default_specification=DRAFT202012)) - for url, schema in cls._SCHEMA_STORE.items() - ]) + cls._REGISTRY = Registry().with_resources( + [ + (url, Resource.from_contents(schema, default_specification=DRAFT202012)) + for url, schema in cls._SCHEMA_STORE.items() + ] + ) @property def registry(self) -> Registry: diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index 2243bec7a4..576bc175c3 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -11,7 +11,7 @@ from mlos_bench.os_environ import environ -class DictTemplater: # pylint: disable=too-few-public-methods +class DictTemplater: # pylint: disable=too-few-public-methods """Simple class to help with nested dictionary $var templating.""" def __init__(self, source_dict: Dict[str, Any]): @@ -28,9 +28,9 @@ def __init__(self, source_dict: Dict[str, Any]): # The source/target dictionary to expand. self._dict: Dict[str, Any] = {} - def expand_vars(self, *, - extra_source_dict: Optional[Dict[str, Any]] = None, - use_os_env: bool = False) -> Dict[str, Any]: + def expand_vars( + self, *, extra_source_dict: Optional[Dict[str, Any]] = None, use_os_env: bool = False + ) -> Dict[str, Any]: """ Expand the template variables in the destination dictionary. @@ -51,7 +51,9 @@ def expand_vars(self, *, assert isinstance(self._dict, dict) return self._dict - def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool) -> Any: + def _expand_vars( + self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool + ) -> Any: """Recursively expand $var strings in the currently operating dictionary.""" if isinstance(value, str): # First try to expand all $vars internally. @@ -65,7 +67,7 @@ def _expand_vars(self, value: Any, extra_source_dict: Optional[Dict[str, Any]], elif isinstance(value, dict): # Note: we use a loop instead of dict comprehension in order to # allow secondary expansion of subsequent values immediately. 
- for (key, val) in value.items(): + for key, val in value.items(): value[key] = self._expand_vars(val, extra_source_dict, use_os_env) elif isinstance(value, list): value = [self._expand_vars(val, extra_source_dict, use_os_env) for val in value] diff --git a/mlos_bench/mlos_bench/environments/__init__.py b/mlos_bench/mlos_bench/environments/__init__.py index 8a4df5a5b2..ff649af50e 100644 --- a/mlos_bench/mlos_bench/environments/__init__.py +++ b/mlos_bench/mlos_bench/environments/__init__.py @@ -13,11 +13,11 @@ from mlos_bench.environments.status import Status __all__ = [ - 'Status', - 'Environment', - 'MockEnv', - 'RemoteEnv', - 'LocalEnv', - 'LocalFileShareEnv', - 'CompositeEnv', + "Status", + "Environment", + "MockEnv", + "RemoteEnv", + "LocalEnv", + "LocalFileShareEnv", + "CompositeEnv", ] diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index d91bb57041..0c3300aa10 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -44,15 +44,16 @@ class Environment(metaclass=abc.ABCMeta): """An abstract base of all benchmark environments.""" @classmethod - def new(cls, - *, - env_name: str, - class_name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None, - ) -> "Environment": + def new( + cls, + *, + env_name: str, + class_name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -90,16 +91,18 @@ def new(cls, config=config, global_config=global_config, tunables=tunables, - service=service + service=service, ) - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment with a given config. @@ -130,24 +133,32 @@ def __init__(self, self._const_args: Dict[str, TunableValue] = config.get("const_args", {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Environment: '%s' Service: %s", name, - self._service.pprint() if self._service else None) + _LOG.debug( + "Environment: '%s' Service: %s", + name, + self._service.pprint() if self._service else None, + ) if tunables is None: - _LOG.warning("No tunables provided for %s. Tunable inheritance across composite environments may be broken.", name) + _LOG.warning( + ( + "No tunables provided for %s. " + "Tunable inheritance across composite environments may be broken." 
+ ), + name, + ) tunables = TunableGroups() groups = self._expand_groups( - config.get("tunable_params", []), - (global_config or {}).get("tunable_params_map", {})) + config.get("tunable_params", []), (global_config or {}).get("tunable_params_map", {}) + ) _LOG.debug("Tunable groups for: '%s' :: %s", name, groups) self._tunable_params = tunables.subgroup(groups) # If a parameter comes from the tunables, do not require it in the const_args or globals - req_args = ( - set(config.get("required_args", [])) - - set(self._tunable_params.get_param_values().keys()) + req_args = set(config.get("required_args", [])) - set( + self._tunable_params.get_param_values().keys() ) merge_parameters(dest=self._const_args, source=global_config, required_keys=req_args) self._const_args = self._expand_vars(self._const_args, global_config or {}) @@ -156,8 +167,7 @@ def __init__(self, _LOG.debug("Parameters for '%s' :: %s", name, self._params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Config for: '%s'\n%s", - name, json.dumps(self.config, indent=2)) + _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2)) def _validate_json_config(self, config: dict, name: str) -> None: """Reconstructs a basic json config that this class might have been instantiated @@ -174,8 +184,9 @@ def _validate_json_config(self, config: dict, name: str) -> None: ConfigSchema.ENVIRONMENT.validate(json_config) @staticmethod - def _expand_groups(groups: Iterable[str], - groups_exp: Dict[str, Union[str, Sequence[str]]]) -> List[str]: + def _expand_groups( + groups: Iterable[str], groups_exp: Dict[str, Union[str, Sequence[str]]] + ) -> List[str]: """ Expand `$tunable_group` into actual names of the tunable groups. @@ -197,7 +208,12 @@ def _expand_groups(groups: Iterable[str], if grp[:1] == "$": tunable_group_name = grp[1:] if tunable_group_name not in groups_exp: - raise KeyError(f"Expected tunable group name ${tunable_group_name} undefined in {groups_exp}") + raise KeyError( + ( + f"Expected tunable group name ${tunable_group_name} " + "undefined in {groups_exp}" + ) + ) add_groups = groups_exp[tunable_group_name] res += [add_groups] if isinstance(add_groups, str) else add_groups else: @@ -205,7 +221,9 @@ def _expand_groups(groups: Iterable[str], return res @staticmethod - def _expand_vars(params: Dict[str, TunableValue], global_config: Dict[str, TunableValue]) -> dict: + def _expand_vars( + params: Dict[str, TunableValue], global_config: Dict[str, TunableValue] + ) -> dict: """Expand `$var` into actual values of the variables.""" return DictTemplater(params).expand_vars(extra_source_dict=global_config) @@ -214,7 +232,7 @@ def _config_loader_service(self) -> "SupportsConfigLoading": assert self._service is not None return self._service.config_loader_service - def __enter__(self) -> 'Environment': + def __enter__(self) -> "Environment": """Enter the environment's benchmarking context.""" _LOG.debug("Environment START :: %s", self) assert not self._in_context @@ -223,9 +241,12 @@ def __enter__(self) -> 'Environment': self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """Exit the context of the benchmarking environment.""" ex_throw = None if ex_val is None: @@ -293,7 +314,8 @@ def _combine_tunables(self, tunables: TunableGroups) -> 
Dict[str, TunableValue]: """ return tunables.get_param_values( group_names=list(self._tunable_params.get_covariant_group_names()), - into_params=self._const_args.copy()) + into_params=self._const_args.copy(), + ) @property def tunable_params(self) -> TunableGroups: @@ -317,7 +339,8 @@ def parameters(self) -> Dict[str, TunableValue]: Returns ------- parameters : Dict[str, TunableValue] - Key/value pairs of all environment parameters (i.e., `const_args` and `tunable_params`). + Key/value pairs of all environment parameters + (i.e., `const_args` and `tunable_params`). """ return self._params @@ -354,10 +377,15 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - # (Derived classes still have to check `self._tunable_params.is_updated()`). is_updated = self._tunable_params.is_updated() if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Env '%s': Tunable groups reset = %s :: %s", self, is_updated, { - name: self._tunable_params.is_updated([name]) - for name in self._tunable_params.get_covariant_group_names() - }) + _LOG.debug( + "Env '%s': Tunable groups reset = %s :: %s", + self, + is_updated, + { + name: self._tunable_params.is_updated([name]) + for name in self._tunable_params.get_covariant_group_names() + }, + ) else: _LOG.info("Env '%s': Tunable groups reset = %s", self, is_updated) diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index 4bf38a5ef2..6f8961ce06 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -23,13 +23,15 @@ class CompositeEnv(Environment): """Composite benchmark environment.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment with a given config. @@ -49,8 +51,13 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) # By default, the Environment includes only the tunables explicitly specified # in the "tunable_params" section of the config. `CompositeEnv`, however, must @@ -66,17 +73,19 @@ def __init__(self, # each CompositeEnv gets a copy of the original global config and adjusts it with # the `const_args` specific to it. 
global_config = (global_config or {}).copy() - for (key, val) in self._const_args.items(): + for key, val in self._const_args.items(): global_config.setdefault(key, val) for child_config_file in config.get("include_children", []): for env in self._config_loader_service.load_environment_list( - child_config_file, tunables, global_config, self._const_args, self._service): + child_config_file, tunables, global_config, self._const_args, self._service + ): self._add_child(env, tunables) for child_config in config.get("children", []): env = self._config_loader_service.build_environment( - child_config, tunables, global_config, self._const_args, self._service) + child_config, tunables, global_config, self._const_args, self._service + ) self._add_child(env, tunables) _LOG.debug("Build composite environment '%s' END: %s", self, self._tunable_params) @@ -88,9 +97,12 @@ def __enter__(self) -> Environment: self._child_contexts = [env.__enter__() for env in self._children] return super().__enter__() - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: ex_throw = None for env in reversed(self._children): try: @@ -126,8 +138,11 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: pretty : str Pretty-printed environment configuration. """ - return super().pprint(indent, level) + '\n' + '\n'.join( - child.pprint(indent, level + 1) for child in self._children) + return ( + super().pprint(indent, level) + + "\n" + + "\n".join(child.pprint(indent, level + 1) for child in self._children) + ) def _add_child(self, env: Environment, tunables: TunableGroups) -> None: """ @@ -160,7 +175,8 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - """ assert self._in_context self._is_ready = super().setup(tunables, global_config) and all( - env_context.setup(tunables, global_config) for env_context in self._child_contexts) + env_context.setup(tunables, global_config) for env_context in self._child_contexts + ) return self._is_ready def teardown(self) -> None: @@ -234,5 +250,6 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: final_status = final_status or status _LOG.info("Final status: %s :: %s", self, final_status) - # Return the status and the timestamp of the last child environment or the first failed child environment. + # Return the status and the timestamp of the last child environment or the + # first failed child environment. 
return (final_status, timestamp, joint_telemetry) diff --git a/mlos_bench/mlos_bench/environments/local/__init__.py b/mlos_bench/mlos_bench/environments/local/__init__.py index 9a51941529..7de10647f1 100644 --- a/mlos_bench/mlos_bench/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/environments/local/__init__.py @@ -8,6 +8,6 @@ from mlos_bench.environments.local.local_fileshare_env import LocalFileShareEnv __all__ = [ - 'LocalEnv', - 'LocalFileShareEnv', + "LocalEnv", + "LocalFileShareEnv", ] diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index 8cb877a9d0..071827d364 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -32,13 +32,15 @@ class LocalEnv(ScriptEnv): # pylint: disable=too-many-instance-attributes """Scheduler-side Environment that runs scripts locally.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for local execution. @@ -61,11 +63,17 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ - "LocalEnv requires a service that supports local execution" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsLocalExec + ), "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service self._temp_dir: Optional[str] = None @@ -79,13 +87,18 @@ def __init__(self, def __enter__(self) -> Environment: assert self._temp_dir is None and self._temp_dir_context is None - self._temp_dir_context = self._local_exec_service.temp_dir_context(self.config.get("temp_dir")) + self._temp_dir_context = self._local_exec_service.temp_dir_context( + self.config.get("temp_dir") + ) self._temp_dir = self._temp_dir_context.__enter__() return super().__enter__() - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """Exit the context of the benchmarking environment.""" assert not (self._temp_dir is None or self._temp_dir_context is None) self._temp_dir_context.__exit__(ex_type, ex_val, ex_tb) @@ -131,10 +144,14 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - fname = path_join(self._temp_dir, self._dump_meta_file) _LOG.debug("Dump tunables metadata to file: %s", fname) with open(fname, "wt", encoding="utf-8") as fh_meta: - json.dump({ - tunable.name: tunable.meta - for (tunable, _group) in self._tunable_params if tunable.meta - }, fh_meta) + json.dump( + { + tunable.name: tunable.meta + for (tunable, _group) in self._tunable_params + if tunable.meta + }, + fh_meta, + ) if 
self._script_setup: (return_code, _output) = self._local_exec(self._script_setup, self._temp_dir) @@ -174,18 +191,28 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: _LOG.debug("Not reading the data at: %s", self) return (Status.SUCCEEDED, timestamp, stdout_data) - data = self._normalize_columns(pandas.read_csv( - self._config_loader_service.resolve_path( - self._read_results_file, extra_paths=[self._temp_dir]), - index_col=False, - )) + data = self._normalize_columns( + pandas.read_csv( + self._config_loader_service.resolve_path( + self._read_results_file, extra_paths=[self._temp_dir] + ), + index_col=False, + ) + ) _LOG.debug("Read data:\n%s", data) if list(data.columns) == ["metric", "value"]: - _LOG.info("Local results have (metric,value) header and %d rows: assume long format", len(data)) + _LOG.info( + "Local results have (metric,value) header and %d rows: assume long format", + len(data), + ) data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list()) # Try to convert string metrics to numbers. - data = data.apply(pandas.to_numeric, errors='coerce').fillna(data) # type: ignore[assignment] # (false positive) + # type: ignore[assignment] # (false positive) + data = data.apply( # type: ignore[assignment] # (false positive) + pandas.to_numeric, + errors="coerce", + ).fillna(data) elif len(data) == 1: _LOG.info("Local results have 1 row: assume wide format") else: @@ -201,8 +228,8 @@ def _normalize_columns(data: pandas.DataFrame) -> pandas.DataFrame: # Windows cmd interpretation of > redirect symbols can leave trailing spaces in # the final column, which leads to misnamed columns. # For now, we simply strip trailing spaces from column names to account for that. - if sys.platform == 'win32': - data.rename(str.rstrip, axis='columns', inplace=True) + if sys.platform == "win32": + data.rename(str.rstrip, axis="columns", inplace=True) return data def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: @@ -214,24 +241,23 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: assert self._temp_dir is not None try: fname = self._config_loader_service.resolve_path( - self._read_telemetry_file, extra_paths=[self._temp_dir]) + self._read_telemetry_file, extra_paths=[self._temp_dir] + ) # TODO: Use the timestamp of the CSV file as our status timestamp? # FIXME: We should not be assuming that the only output file type is a CSV. - data = self._normalize_columns( - pandas.read_csv(fname, index_col=False)) + data = self._normalize_columns(pandas.read_csv(fname, index_col=False)) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") expected_col_names = ["timestamp", "metric", "value"] if len(data.columns) != len(expected_col_names): - raise ValueError(f'Telemetry data must have columns {expected_col_names}') + raise ValueError(f"Telemetry data must have columns {expected_col_names}") if list(data.columns) != expected_col_names: # Assume no header - this is ok for telemetry data. 
- data = pandas.read_csv( - fname, index_col=False, names=expected_col_names) + data = pandas.read_csv(fname, index_col=False, names=expected_col_names) data.iloc[:, 0] = datetime_parser(data.iloc[:, 0], origin="local") except FileNotFoundError as ex: @@ -240,10 +266,14 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: _LOG.debug("Read telemetry data:\n%s", data) col_dtypes: Mapping[int, Type] = {0: datetime} - return (status, timestamp, [ - (pandas.Timestamp(ts).to_pydatetime(), metric, value) - for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) - ]) + return ( + status, + timestamp, + [ + (pandas.Timestamp(ts).to_pydatetime(), metric, value) + for (ts, metric, value) in data.to_records(index=False, column_dtypes=col_dtypes) + ], + ) def teardown(self) -> None: """Clean up the local environment.""" @@ -273,7 +303,8 @@ def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple env_params = self._get_env_params() _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params) (return_code, stdout, stderr) = self._local_exec_service.local_exec( - script, env=env_params, cwd=cwd) + script, env=env_params, cwd=cwd + ) if return_code != 0: _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr) return (return_code, {"stdout": stdout, "stderr": stderr}) diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index fd6c2c1127..2996ea8cd2 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -27,13 +27,15 @@ class LocalFileShareEnv(LocalEnv): to the shared file storage. """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new application environment with a given config. @@ -57,14 +59,22 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) - assert self._service is not None and isinstance(self._service, SupportsLocalExec), \ - "LocalEnv requires a service that supports local execution" + assert self._service is not None and isinstance( + self._service, SupportsLocalExec + ), "LocalEnv requires a service that supports local execution" self._local_exec_service: SupportsLocalExec = self._service - assert self._service is not None and isinstance(self._service, SupportsFileShareOps), \ - "LocalEnv requires a service that supports file upload/download operations" + assert self._service is not None and isinstance( + self._service, SupportsFileShareOps + ), "LocalEnv requires a service that supports file upload/download operations" self._file_share_service: SupportsFileShareOps = self._service self._upload = self._template_from_to("upload") @@ -74,14 +84,12 @@ def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]: """Convert a list of {"from": "...", "to": "..."} to a list of pairs of string.Template objects so that we can plug in self._params into it later. """ - return [ - (Template(d['from']), Template(d['to'])) - for d in self.config.get(config_key, []) - ] + return [(Template(d["from"]), Template(d["to"])) for d in self.config.get(config_key, [])] @staticmethod - def _expand(from_to: Iterable[Tuple[Template, Template]], - params: Mapping[str, TunableValue]) -> Generator[Tuple[str, str], None, None]: + def _expand( + from_to: Iterable[Tuple[Template, Template]], params: Mapping[str, TunableValue] + ) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. 
@@ -117,9 +125,14 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for (path_from, path_to) in self._expand(self._upload, params): - self._file_share_service.upload(self._params, self._config_loader_service.resolve_path( - path_from, extra_paths=[self._temp_dir]), path_to) + for path_from, path_to in self._expand(self._upload, params): + self._file_share_service.upload( + self._params, + self._config_loader_service.resolve_path( + path_from, extra_paths=[self._temp_dir] + ), + path_to, + ) return self._is_ready def _download_files(self, ignore_missing: bool = False) -> None: @@ -135,11 +148,15 @@ def _download_files(self, ignore_missing: bool = False) -> None: assert self._temp_dir is not None params = self._get_env_params(restrict=False) params["PWD"] = self._temp_dir - for (path_from, path_to) in self._expand(self._download, params): + for path_from, path_to in self._expand(self._download, params): try: - self._file_share_service.download(self._params, - path_from, self._config_loader_service.resolve_path( - path_to, extra_paths=[self._temp_dir])) + self._file_share_service.download( + self._params, + path_from, + self._config_loader_service.resolve_path( + path_to, extra_paths=[self._temp_dir] + ), + ) except FileNotFoundError as ex: _LOG.warning("Cannot download: %s", path_from) if not ignore_missing: diff --git a/mlos_bench/mlos_bench/environments/mock_env.py b/mlos_bench/mlos_bench/environments/mock_env.py index 16ff1195de..a8888c5e28 100644 --- a/mlos_bench/mlos_bench/environments/mock_env.py +++ b/mlos_bench/mlos_bench/environments/mock_env.py @@ -25,13 +25,15 @@ class MockEnv(Environment): _NOISE_VAR = 0.2 """Variance of the Gaussian noise added to the benchmark value.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment that produces mock benchmark data. @@ -51,8 +53,13 @@ def __init__(self, service: Service An optional service object. Not used by this class. """ - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) seed = int(self.config.get("mock_env_seed", -1)) self._random = random.Random(seed or None) if seed >= 0 else None self._range = self.config.get("mock_env_range") @@ -77,9 +84,9 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: return result # Simple convex function of all tunable parameters. - score = numpy.mean(numpy.square([ - self._normalized(tunable) for (tunable, _group) in self._tunable_params - ])) + score = numpy.mean( + numpy.square([self._normalized(tunable) for (tunable, _group) in self._tunable_params]) + ) # Add noise and shift the benchmark value from [0, 1] to a given range. 
noise = self._random.gauss(0, self._NOISE_VAR) if self._random else 0 @@ -98,11 +105,11 @@ def _normalized(tunable: Tunable) -> float: """ val = None if tunable.is_categorical: - val = (tunable.categories.index(tunable.category) / - float(len(tunable.categories) - 1)) + val = tunable.categories.index(tunable.category) / float(len(tunable.categories) - 1) elif tunable.is_numerical: - val = ((tunable.numerical_value - tunable.range[0]) / - float(tunable.range[1] - tunable.range[0])) + val = (tunable.numerical_value - tunable.range[0]) / float( + tunable.range[1] - tunable.range[0] + ) else: raise ValueError("Invalid parameter type: " + tunable.type) # Explicitly clip the value in case of numerical errors. diff --git a/mlos_bench/mlos_bench/environments/remote/__init__.py b/mlos_bench/mlos_bench/environments/remote/__init__.py index 3b26f8d6a7..10608a6980 100644 --- a/mlos_bench/mlos_bench/environments/remote/__init__.py +++ b/mlos_bench/mlos_bench/environments/remote/__init__.py @@ -12,10 +12,10 @@ from mlos_bench.environments.remote.vm_env import VMEnv __all__ = [ - 'HostEnv', - 'NetworkEnv', - 'OSEnv', - 'RemoteEnv', - 'SaaSEnv', - 'VMEnv', + "HostEnv", + "NetworkEnv", + "OSEnv", + "RemoteEnv", + "SaaSEnv", + "VMEnv", ] diff --git a/mlos_bench/mlos_bench/environments/remote/host_env.py b/mlos_bench/mlos_bench/environments/remote/host_env.py index ae88fa2197..c6d1c0145e 100644 --- a/mlos_bench/mlos_bench/environments/remote/host_env.py +++ b/mlos_bench/mlos_bench/environments/remote/host_env.py @@ -18,13 +18,15 @@ class HostEnv(Environment): """Remote host environment.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for host operations. @@ -45,10 +47,17 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM/host, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) - assert self._service is not None and isinstance(self._service, SupportsHostProvisioning), \ - "HostEnv requires a service that supports host provisioning operations" + assert self._service is not None and isinstance( + self._service, SupportsHostProvisioning + ), "HostEnv requires a service that supports host provisioning operations" self._host_service: SupportsHostProvisioning = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index c3ad8ccd82..3f36345b58 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -25,13 +25,15 @@ class NetworkEnv(Environment): but no real tuning is expected for it ... yet. 
""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for network operations. @@ -52,14 +54,21 @@ def __init__(self, An optional service object (e.g., providing methods to deploy a network, etc.). """ - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) # Virtual networks can be used for more than one experiment, so by default # we don't attempt to deprovision them. self._deprovision_on_teardown = config.get("deprovision_on_teardown", False) - assert self._service is not None and isinstance(self._service, SupportsNetworkProvisioning), \ - "NetworkEnv requires a service that supports network provisioning" + assert self._service is not None and isinstance( + self._service, SupportsNetworkProvisioning + ), "NetworkEnv requires a service that supports network provisioning" self._network_service: SupportsNetworkProvisioning = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -101,7 +110,9 @@ def teardown(self) -> None: return # Else _LOG.info("Network tear down: %s", self) - (status, params) = self._network_service.deprovision_network(self._params, ignore_errors=True) + (status, params) = self._network_service.deprovision_network( + self._params, ignore_errors=True + ) if status.is_pending(): (status, _) = self._network_service.wait_network_deployment(params, is_setup=False) diff --git a/mlos_bench/mlos_bench/environments/remote/os_env.py b/mlos_bench/mlos_bench/environments/remote/os_env.py index 68a6f5fbe7..4328b8f694 100644 --- a/mlos_bench/mlos_bench/environments/remote/os_env.py +++ b/mlos_bench/mlos_bench/environments/remote/os_env.py @@ -20,13 +20,15 @@ class OSEnv(Environment): """OS Level Environment for a host.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for remote execution. @@ -49,14 +51,22 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance(self._service, SupportsOSOps), \ - "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance( + self._service, SupportsOSOps + ), "RemoteEnv requires a service that supports host operations" self._os_service: SupportsOSOps = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index 94e789b198..87a76be45a 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -32,13 +32,15 @@ class RemoteEnv(ScriptEnv): e.g. Application Environment """ - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for remote execution. @@ -61,18 +63,25 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a Host, VM, OS, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) self._wait_boot = self.config.get("wait_boot", False) - assert self._service is not None and isinstance(self._service, SupportsRemoteExec), \ - "RemoteEnv requires a service that supports remote execution operations" + assert self._service is not None and isinstance( + self._service, SupportsRemoteExec + ), "RemoteEnv requires a service that supports remote execution operations" self._remote_exec_service: SupportsRemoteExec = self._service if self._wait_boot: - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -168,7 +177,8 @@ def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optiona env_params = self._get_env_params() _LOG.debug("Submit script: %s with %s", self, env_params) (status, output) = self._remote_exec_service.remote_exec( - script, config=self._params, env_params=env_params) + script, config=self._params, env_params=env_params + ) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) if status in {Status.PENDING, Status.SUCCEEDED}: (status, output) = self._remote_exec_service.get_remote_exec_results(output) diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index 024430e22a..0b64bc679f 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -19,13 +19,15 @@ class SaaSEnv(Environment): """Cloud-based (configurable) SaaS environment.""" - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for (configurable) cloud-based SaaS instance. @@ -46,15 +48,22 @@ def __init__(self, An optional service object (e.g., providing methods to configure the remote service). 
""" - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) - - assert self._service is not None and isinstance(self._service, SupportsHostOps), \ - "RemoteEnv requires a service that supports host operations" + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) + + assert self._service is not None and isinstance( + self._service, SupportsHostOps + ), "RemoteEnv requires a service that supports host operations" self._host_service: SupportsHostOps = self._service - assert self._service is not None and isinstance(self._service, SupportsRemoteConfig), \ - "SaaSEnv requires a service that supports remote host configuration API" + assert self._service is not None and isinstance( + self._service, SupportsRemoteConfig + ), "SaaSEnv requires a service that supports remote host configuration API" self._config_service: SupportsRemoteConfig = self._service def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) -> bool: @@ -80,7 +89,8 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False (status, _) = self._config_service.configure( - self._params, self._tunable_params.get_param_values()) + self._params, self._tunable_params.get_param_values() + ) if not status.is_succeeded(): return False @@ -89,7 +99,7 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False # Azure Flex DB instances currently require a VM reboot after reconfiguration. - if res.get('isConfigPendingRestart') or res.get('isConfigPendingReboot'): + if res.get("isConfigPendingRestart") or res.get("isConfigPendingReboot"): _LOG.info("Restarting: %s", self) (status, params) = self._host_service.restart_host(self._params) if status.is_pending(): diff --git a/mlos_bench/mlos_bench/environments/script_env.py b/mlos_bench/mlos_bench/environments/script_env.py index 745430ca69..fe31d6fb13 100644 --- a/mlos_bench/mlos_bench/environments/script_env.py +++ b/mlos_bench/mlos_bench/environments/script_env.py @@ -23,13 +23,15 @@ class ScriptEnv(Environment, metaclass=abc.ABCMeta): _RE_INVALID = re.compile(r"[^a-zA-Z0-9_]") - def __init__(self, - *, - name: str, - config: dict, - global_config: Optional[dict] = None, - tunables: Optional[TunableGroups] = None, - service: Optional[Service] = None): + def __init__( + self, + *, + name: str, + config: dict, + global_config: Optional[dict] = None, + tunables: Optional[TunableGroups] = None, + service: Optional[Service] = None, + ): """ Create a new environment for script execution. @@ -59,19 +61,29 @@ def __init__(self, An optional service object (e.g., providing methods to deploy or reboot a VM, etc.). 
""" - super().__init__(name=name, config=config, global_config=global_config, - tunables=tunables, service=service) + super().__init__( + name=name, + config=config, + global_config=global_config, + tunables=tunables, + service=service, + ) self._script_setup = self.config.get("setup") self._script_run = self.config.get("run") self._script_teardown = self.config.get("teardown") self._shell_env_params: Iterable[str] = self.config.get("shell_env_params", []) - self._shell_env_params_rename: Dict[str, str] = self.config.get("shell_env_params_rename", {}) + self._shell_env_params_rename: Dict[str, str] = self.config.get( + "shell_env_params_rename", {} + ) results_stdout_pattern = self.config.get("results_stdout_pattern") - self._results_stdout_pattern: Optional[re.Pattern[str]] = \ - re.compile(results_stdout_pattern, flags=re.MULTILINE) if results_stdout_pattern else None + self._results_stdout_pattern: Optional[re.Pattern[str]] = ( + re.compile(results_stdout_pattern, flags=re.MULTILINE) + if results_stdout_pattern + else None + ) def _get_env_params(self, restrict: bool = True) -> Dict[str, str]: """ @@ -112,4 +124,6 @@ def _extract_stdout_results(self, stdout: str) -> Dict[str, TunableValue]: if not self._results_stdout_pattern: return {} _LOG.debug("Extract regex: '%s' from: '%s'", self._results_stdout_pattern, stdout) - return {key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout)} + return { + key: try_parse_val(val) for (key, val) in self._results_stdout_pattern.findall(stdout) + } diff --git a/mlos_bench/mlos_bench/event_loop_context.py b/mlos_bench/mlos_bench/event_loop_context.py index 8684844063..65285e5d66 100644 --- a/mlos_bench/mlos_bench/event_loop_context.py +++ b/mlos_bench/mlos_bench/event_loop_context.py @@ -18,7 +18,7 @@ else: from typing_extensions import TypeAlias -CoroReturnType = TypeVar('CoroReturnType') # pylint: disable=invalid-name +CoroReturnType = TypeVar("CoroReturnType") # pylint: disable=invalid-name if sys.version_info >= (3, 9): FutureReturnType: TypeAlias = Future[CoroReturnType] else: diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index 1a0caa6bba..106b853043 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -32,7 +32,7 @@ from mlos_bench.util import try_parse_val _LOG_LEVEL = logging.INFO -_LOG_FORMAT = '%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s' +_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s" logging.basicConfig(level=_LOG_LEVEL, format=_LOG_FORMAT) _LOG = logging.getLogger(__name__) @@ -46,14 +46,16 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # pylint: disable=too-many-statements _LOG.info("Launch: %s", description) epilog = """ - Additional --key=value pairs can be specified to augment or override values listed in --globals. - Other required_args values can also be pulled from shell environment variables. + Additional --key=value pairs can be specified to augment or override + values listed in --globals. + Other required_args values can also be pulled from shell environment + variables. 
- For additional details, please see the website or the README.md files in the source tree: + For additional details, please see the website or the README.md files in + the source tree: """ - parser = argparse.ArgumentParser(description=f"{description} : {long_text}", - epilog=epilog) + parser = argparse.ArgumentParser(description=f"{description} : {long_text}", epilog=epilog) (args, args_rest) = self._parse_args(parser, argv) # Bootstrap config loader: command line takes priority. @@ -91,16 +93,18 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st args_rest, {key: val for (key, val) in config.items() if key not in vars(args)}, ) - # experiment_id is generally taken from --globals files, but we also allow overriding it on the CLI. + # experiment_id is generally taken from --globals files, but we also allow + # overriding it on the CLI. # It's useful to keep it there explicitly mostly for the --help output. if args.experiment_id: - self.global_config['experiment_id'] = args.experiment_id - # trial_config_repeat_count is a scheduler property but it's convenient to set it via command line + self.global_config["experiment_id"] = args.experiment_id + # trial_config_repeat_count is a scheduler property but it's convenient to + # set it via command line if args.trial_config_repeat_count: self.global_config["trial_config_repeat_count"] = args.trial_config_repeat_count # Ensure that the trial_id is present since it gets used by some other # configs but is typically controlled by the run optimize loop. - self.global_config.setdefault('trial_id', 1) + self.global_config.setdefault("trial_id", 1) self.global_config = DictTemplater(self.global_config).expand_vars(use_os_env=True) assert isinstance(self.global_config, dict) @@ -108,24 +112,29 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st # --service cli args should override the config file values. service_files: List[str] = config.get("services", []) + (args.service or []) assert isinstance(self._parent_service, SupportsConfigLoading) - self._parent_service = self._parent_service.load_services(service_files, self.global_config, self._parent_service) + self._parent_service = self._parent_service.load_services( + service_files, self.global_config, self._parent_service + ) env_path = args.environment or config.get("environment") if not env_path: _LOG.error("No environment config specified.") - parser.error("At least the Environment config must be specified." + - " Run `mlos_bench --help` and consult `README.md` for more info.") + parser.error( + "At least the Environment config must be specified." + + " Run `mlos_bench --help` and consult `README.md` for more info." 
+ ) self.root_env_config = self._config_loader.resolve_path(env_path) self.environment: Environment = self._config_loader.load_environment( - self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service) + self.root_env_config, TunableGroups(), self.global_config, service=self._parent_service + ) _LOG.info("Init environment: %s", self.environment) # NOTE: Init tunable values *after* the Environment, but *before* the Optimizer self.tunables = self._init_tunable_values( args.random_init or config.get("random_init", False), config.get("random_seed") if args.random_seed is None else args.random_seed, - config.get("tunable_values", []) + (args.tunable_values or []) + config.get("tunable_values", []) + (args.tunable_values or []), ) _LOG.info("Init tunables: %s", self.tunables) @@ -135,7 +144,11 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st self.storage = self._load_storage(args.storage or config.get("storage")) _LOG.info("Init storage: %s", self.storage) - self.teardown: bool = bool(args.teardown) if args.teardown is not None else bool(config.get("teardown", True)) + self.teardown: bool = ( + bool(args.teardown) + if args.teardown is not None + else bool(config.get("teardown", True)) + ) self.scheduler = self._load_scheduler(args.scheduler or config.get("scheduler")) _LOG.info("Init scheduler: %s", self.scheduler) @@ -150,85 +163,150 @@ def service(self) -> Service: return self._parent_service @staticmethod - def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> Tuple[argparse.Namespace, List[str]]: + def _parse_args( + parser: argparse.ArgumentParser, argv: Optional[List[str]] + ) -> Tuple[argparse.Namespace, List[str]]: """Parse the command line arguments.""" parser.add_argument( - '--config', required=False, - help='Main JSON5 configuration file. Its keys are the same as the' + - ' command line options and can be overridden by the latter.\n' + - '\n' + - ' See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ ' + - ' for additional config examples for this and other arguments.') + "--config", + required=False, + help="Main JSON5 configuration file. Its keys are the same as the" + + " command line options and can be overridden by the latter.\n" + + "\n" + + " See the `mlos_bench/config/` tree at https://github.com/microsoft/MLOS/ " + + " for additional config examples for this and other arguments.", + ) parser.add_argument( - '--log_file', '--log-file', required=False, - help='Path to the log file. Use stdout if omitted.') + "--log_file", + "--log-file", + required=False, + help="Path to the log file. Use stdout if omitted.", + ) parser.add_argument( - '--log_level', '--log-level', required=False, type=str, - help=f'Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}.' + - ' Set to DEBUG for debug, WARNING for warnings only.') + "--log_level", + "--log-level", + required=False, + type=str, + help=f"Logging level. Default is {logging.getLevelName(_LOG_LEVEL)}." 
+ + " Set to DEBUG for debug, WARNING for warnings only.", + ) parser.add_argument( - '--config_path', '--config-path', '--config-paths', '--config_paths', - nargs="+", action='extend', required=False, - help='One or more locations of JSON config files.') + "--config_path", + "--config-path", + "--config-paths", + "--config_paths", + nargs="+", + action="extend", + required=False, + help="One or more locations of JSON config files.", + ) parser.add_argument( - '--service', '--services', - nargs='+', action='extend', required=False, - help='Path to JSON file with the configuration of the service(s) for environment(s) to use.') + "--service", + "--services", + nargs="+", + action="extend", + required=False, + help=( + "Path to JSON file with the configuration " + "of the service(s) for environment(s) to use." + ), + ) parser.add_argument( - '--environment', required=False, - help='Path to JSON file with the configuration of the benchmarking environment(s).') + "--environment", + required=False, + help="Path to JSON file with the configuration of the benchmarking environment(s).", + ) parser.add_argument( - '--optimizer', required=False, - help='Path to the optimizer configuration file. If omitted, run' + - ' a single trial with default (or specified in --tunable_values).') + "--optimizer", + required=False, + help="Path to the optimizer configuration file. If omitted, run" + + " a single trial with default (or specified in --tunable_values).", + ) parser.add_argument( - '--trial_config_repeat_count', '--trial-config-repeat-count', required=False, type=int, - help='Number of times to repeat each config. Default is 1 trial per config, though more may be advised.') + "--trial_config_repeat_count", + "--trial-config-repeat-count", + required=False, + type=int, + help=( + "Number of times to repeat each config. " + "Default is 1 trial per config, though more may be advised." + ), + ) parser.add_argument( - '--scheduler', required=False, - help='Path to the scheduler configuration file. By default, use' + - ' a single worker synchronous scheduler.') + "--scheduler", + required=False, + help="Path to the scheduler configuration file. By default, use" + + " a single worker synchronous scheduler.", + ) parser.add_argument( - '--storage', required=False, - help='Path to the storage configuration file.' + - ' If omitted, use the ephemeral in-memory SQL storage.') + "--storage", + required=False, + help="Path to the storage configuration file." + + " If omitted, use the ephemeral in-memory SQL storage.", + ) parser.add_argument( - '--random_init', '--random-init', required=False, default=False, - dest='random_init', action='store_true', - help='Initialize tunables with random values. (Before applying --tunable_values).') + "--random_init", + "--random-init", + required=False, + default=False, + dest="random_init", + action="store_true", + help="Initialize tunables with random values. (Before applying --tunable_values).", + ) parser.add_argument( - '--random_seed', '--random-seed', required=False, type=int, - help='Seed to use with --random_init') + "--random_seed", + "--random-seed", + required=False, + type=int, + help="Seed to use with --random_init", + ) parser.add_argument( - '--tunable_values', '--tunable-values', nargs="+", action='extend', required=False, - help='Path to one or more JSON files that contain values of the tunable' + - ' parameters. 
This can be used for a single trial (when no --optimizer' + - ' is specified) or as default values for the first run in optimization.') + "--tunable_values", + "--tunable-values", + nargs="+", + action="extend", + required=False, + help="Path to one or more JSON files that contain values of the tunable" + + " parameters. This can be used for a single trial (when no --optimizer" + + " is specified) or as default values for the first run in optimization.", + ) parser.add_argument( - '--globals', nargs="+", action='extend', required=False, - help='Path to one or more JSON files that contain additional' + - ' [private] parameters of the benchmarking environment.') + "--globals", + nargs="+", + action="extend", + required=False, + help="Path to one or more JSON files that contain additional" + + " [private] parameters of the benchmarking environment.", + ) parser.add_argument( - '--no_teardown', '--no-teardown', required=False, default=None, - dest='teardown', action='store_false', - help='Disable teardown of the environment after the benchmark.') + "--no_teardown", + "--no-teardown", + required=False, + default=None, + dest="teardown", + action="store_false", + help="Disable teardown of the environment after the benchmark.", + ) parser.add_argument( - '--experiment_id', '--experiment-id', required=False, default=None, + "--experiment_id", + "--experiment-id", + required=False, + default=None, help=""" Experiment ID to use for the benchmark. If omitted, the value from the --cli config or --globals is used. @@ -238,7 +316,7 @@ def _parse_args(parser: argparse.ArgumentParser, argv: Optional[List[str]]) -> T changes are made to config files, scripts, versions, etc. This is left as a manual operation as detection of what is "incompatible" is not easily automatable across systems. - """ + """, ) # By default we use the command line arguments, but allow the caller to @@ -278,15 +356,17 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> Dict[str, TunableValue]: _LOG.debug("Parsed config: %s", config) return config - def _load_config(self, - args_globals: Iterable[str], - config_path: Iterable[str], - args_rest: Iterable[str], - global_config: Dict[str, Any]) -> Dict[str, Any]: + def _load_config( + self, + args_globals: Iterable[str], + config_path: Iterable[str], + args_rest: Iterable[str], + global_config: Dict[str, Any], + ) -> Dict[str, Any]: """Get key/value pairs of the global configuration parameters from the specified config files (if any) and command line arguments. """ - for config_file in (args_globals or []): + for config_file in args_globals or []: conf = self._config_loader.load_config(config_file, ConfigSchema.GLOBALS) assert isinstance(conf, dict) global_config.update(conf) @@ -295,8 +375,9 @@ def _load_config(self, global_config["config_path"] = config_path return global_config - def _init_tunable_values(self, random_init: bool, seed: Optional[int], - args_tunables: Optional[str]) -> TunableGroups: + def _init_tunable_values( + self, random_init: bool, seed: Optional[int], args_tunables: Optional[str] + ) -> TunableGroups: """Initialize the tunables and load key/value pairs of the tunable values from given JSON files, if specified. 
""" @@ -305,8 +386,10 @@ def _init_tunable_values(self, random_init: bool, seed: Optional[int], if random_init: tunables = MockOptimizer( - tunables=tunables, service=None, - config={"start_with_defaults": False, "seed": seed}).suggest() + tunables=tunables, + service=None, + config={"start_with_defaults": False, "seed": seed}, + ).suggest() _LOG.debug("Init tunables: random = %s", tunables) if args_tunables is not None: @@ -329,15 +412,20 @@ def _load_optimizer(self, args_optimizer: Optional[str]) -> Optimizer: if args_optimizer is None: # global_config may contain additional properties, so we need to # strip those out before instantiating the basic oneshot optimizer. - config = {key: val for key, val in self.global_config.items() if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS} - return OneShotOptimizer( - self.tunables, config=config, service=self._parent_service) + config = { + key: val + for key, val in self.global_config.items() + if key in OneShotOptimizer.BASE_SUPPORTED_CONFIG_PROPS + } + return OneShotOptimizer(self.tunables, config=config, service=self._parent_service) class_config = self._config_loader.load_config(args_optimizer, ConfigSchema.OPTIMIZER) assert isinstance(class_config, Dict) - optimizer = self._config_loader.build_optimizer(tunables=self.tunables, - service=self._parent_service, - config=class_config, - global_config=self.global_config) + optimizer = self._config_loader.build_optimizer( + tunables=self.tunables, + service=self._parent_service, + config=class_config, + global_config=self.global_config, + ) return optimizer def _load_storage(self, args_storage: Optional[str]) -> Storage: @@ -350,17 +438,20 @@ def _load_storage(self, args_storage: Optional[str]) -> Storage: if args_storage is None: # pylint: disable=import-outside-toplevel from mlos_bench.storage.sql.storage import SqlStorage - return SqlStorage(service=self._parent_service, - config={ - "drivername": "sqlite", - "database": ":memory:", - "lazy_schema_create": True, - }) + + return SqlStorage( + service=self._parent_service, + config={ + "drivername": "sqlite", + "database": ":memory:", + "lazy_schema_create": True, + }, + ) class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE) assert isinstance(class_config, Dict) - storage = self._config_loader.build_storage(service=self._parent_service, - config=class_config, - global_config=self.global_config) + storage = self._config_loader.build_storage( + service=self._parent_service, config=class_config, global_config=self.global_config + ) return storage def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: @@ -376,6 +467,7 @@ def _load_scheduler(self, args_scheduler: Optional[str]) -> Scheduler: if args_scheduler is None: # pylint: disable=import-outside-toplevel from mlos_bench.schedulers.sync_scheduler import SyncScheduler + return SyncScheduler( # All config values can be overridden from global config config={ diff --git a/mlos_bench/mlos_bench/optimizers/__init__.py b/mlos_bench/mlos_bench/optimizers/__init__.py index 167fe022e6..7cd6a8a25a 100644 --- a/mlos_bench/mlos_bench/optimizers/__init__.py +++ b/mlos_bench/mlos_bench/optimizers/__init__.py @@ -10,8 +10,8 @@ from mlos_bench.optimizers.one_shot_optimizer import OneShotOptimizer __all__ = [ - 'Optimizer', - 'MockOptimizer', - 'OneShotOptimizer', - 'MlosCoreOptimizer', + "Optimizer", + "MockOptimizer", + "OneShotOptimizer", + "MlosCoreOptimizer", ] diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py 
b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index e9b4ff8388..9cdecffc81 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -25,7 +25,7 @@ _LOG = logging.getLogger(__name__) -class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes +class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """An abstract interface between the benchmarking framework and mlos_core optimizers. """ @@ -38,11 +38,13 @@ class Optimizer(metaclass=ABCMeta): # pylint: disable=too-many-instance-attr "start_with_defaults", } - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): """ Create a new optimizer for the given configuration space defined by the tunables. @@ -67,19 +69,20 @@ def __init__(self, self._seed = int(config.get("seed", 42)) self._in_context = False - experiment_id = self._global_config.get('experiment_id') + experiment_id = self._global_config.get("experiment_id") self.experiment_id = str(experiment_id).strip() if experiment_id else None self._iter = 0 # If False, use the optimizer to suggest the initial configuration; # if True (default), use the already initialized values for the first iteration. self._start_with_defaults: bool = bool( - strtobool(str(self._config.pop('start_with_defaults', True)))) - self._max_iter = int(self._config.pop('max_suggestions', 100)) + strtobool(str(self._config.pop("start_with_defaults", True))) + ) + self._max_iter = int(self._config.pop("max_suggestions", 100)) - opt_targets: Dict[str, str] = self._config.pop('optimization_targets', {'score': 'min'}) + opt_targets: Dict[str, str] = self._config.pop("optimization_targets", {"score": "min"}) self._opt_targets: Dict[str, Literal[1, -1]] = {} - for (opt_target, opt_dir) in opt_targets.items(): + for opt_target, opt_dir in opt_targets.items(): if opt_dir == "min": self._opt_targets[opt_target] = 1 elif opt_dir == "max": @@ -106,16 +109,19 @@ def __repr__(self) -> str: ) return f"{self.name}({opt_targets},config={self._config})" - def __enter__(self) -> 'Optimizer': + def __enter__(self) -> "Optimizer": """Enter the optimizer's context.""" _LOG.debug("Optimizer START :: %s", self) assert not self._in_context self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """Exit the context of the optimizer.""" if ex_val is None: _LOG.debug("Optimizer END :: %s", self) @@ -199,7 +205,7 @@ def name(self) -> str: return self.__class__.__name__ @property - def targets(self) -> Dict[str, Literal['min', 'max']]: + def targets(self) -> Dict[str, Literal["min", "max"]]: """A dictionary of {target: direction} of optimization targets.""" return { opt_target: "min" if opt_dir == 1 else "max" @@ -214,10 +220,12 @@ def supports_preload(self) -> bool: return True @abstractmethod - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + 
scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: """ Pre-load the optimizer with the bulk data from previous experiments. @@ -235,8 +243,12 @@ def bulk_register(self, is_not_empty : bool True if there is data to register, false otherwise. """ - _LOG.info("Update the optimizer with: %d configs, %d scores, %d status values", - len(configs or []), len(scores or []), len(status or [])) + _LOG.info( + "Update the optimizer with: %d configs, %d scores, %d status values", + len(configs or []), + len(scores or []), + len(status or []), + ) if len(configs or []) != len(scores or []): raise ValueError("Numbers of configs and scores do not match.") if status is not None and len(configs or []) != len(status or []): @@ -264,8 +276,12 @@ def suggest(self) -> TunableGroups: return self._tunables.copy() @abstractmethod - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: """ Register the observation for the given configuration. @@ -286,15 +302,16 @@ def register(self, tunables: TunableGroups, status: Status, Benchmark scores extracted (and possibly transformed) from the dataframe that's being MINIMIZED. """ - _LOG.info("Iteration %d :: Register: %s = %s score: %s", - self._iter, tunables, status, score) + _LOG.info( + "Iteration %d :: Register: %s = %s score: %s", self._iter, tunables, status, score + ) if status.is_succeeded() == (score is None): # XOR raise ValueError("Status and score must be consistent.") return self._get_scores(status, score) - def _get_scores(self, status: Status, - scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] - ) -> Optional[Dict[str, float]]: + def _get_scores( + self, status: Status, scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] + ) -> Optional[Dict[str, float]]: """ Extract a scalar benchmark score from the dataframe. Change the sign if we are maximizing. @@ -323,7 +340,7 @@ def _get_scores(self, status: Status, assert scores is not None target_metrics: Dict[str, float] = {} - for (opt_target, opt_dir) in self._opt_targets.items(): + for opt_target, opt_dir in self._opt_targets.items(): val = scores[opt_target] assert val is not None target_metrics[opt_target] = float(val) * opt_dir @@ -339,7 +356,9 @@ def not_converged(self) -> bool: return self._iter < self._max_iter @abstractmethod - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: """ Get the best observation so far. diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index 3ab1c43ab9..f53e308352 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -47,7 +47,8 @@ def _normalize_weights(weights: List[float]) -> List[float]: def _tunable_to_configspace( - tunable: Tunable, group_name: Optional[str] = None, cost: int = 0) -> ConfigurationSpace: + tunable: Tunable, group_name: Optional[str] = None, cost: int = 0 +) -> ConfigurationSpace: """ Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects, wrapped in a ConfigurationSpace for composability. 
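An illustrative aside, not part of the patch itself: most hunks in this series follow one mechanical rule of black. When a call or signature no longer fits on one line, every argument moves onto its own line and gains a trailing comma (the "magic trailing comma"), which keeps the exploded layout stable on later runs. A minimal sketch of that behavior, assuming the `black` package is installed; the line length used here is only illustrative, since the repository's actual setting lives in its pyproject.toml and is not shown in this hunk:

import black

# A deliberately over-long call; the names are made up purely for illustration.
SRC = (
    "result = do_something(first_argument, second_argument, third_argument, "
    "keyword_argument=some_default_value, other_argument=None)\n"
)

# black.format_str() and black.Mode() are black's public in-memory API.
print(black.format_str(SRC, mode=black.Mode(line_length=99)))
# Expected shape of the output: one argument per line, plus a trailing comma.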
@@ -70,27 +71,28 @@ def _tunable_to_configspace( meta = {"group": group_name, "cost": cost} # {"scaling": ""} if tunable.type == "categorical": - return ConfigurationSpace({ - tunable.name: CategoricalHyperparameter( - name=tunable.name, - choices=tunable.categories, - weights=_normalize_weights(tunable.weights) if tunable.weights else None, - default_value=tunable.default, - meta=meta) - }) + return ConfigurationSpace( + { + tunable.name: CategoricalHyperparameter( + name=tunable.name, + choices=tunable.categories, + weights=_normalize_weights(tunable.weights) if tunable.weights else None, + default_value=tunable.default, + meta=meta, + ) + } + ) distribution: Union[Uniform, Normal, Beta, None] = None if tunable.distribution == "uniform": distribution = Uniform() elif tunable.distribution == "normal": distribution = Normal( - mu=tunable.distribution_params["mu"], - sigma=tunable.distribution_params["sigma"] + mu=tunable.distribution_params["mu"], sigma=tunable.distribution_params["sigma"] ) elif tunable.distribution == "beta": distribution = Beta( - alpha=tunable.distribution_params["alpha"], - beta=tunable.distribution_params["beta"] + alpha=tunable.distribution_params["alpha"], beta=tunable.distribution_params["beta"] ) elif tunable.distribution is not None: raise TypeError(f"Invalid Distribution Type: {tunable.distribution}") @@ -102,22 +104,26 @@ def _tunable_to_configspace( log=bool(tunable.is_log), q=nullable(int, tunable.quantization), distribution=distribution, - default=(int(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None), - meta=meta + default=( + int(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None + ), + meta=meta, ) elif tunable.type == "float": range_hp = Float( name=tunable.name, bounds=tunable.range, log=bool(tunable.is_log), - q=tunable.quantization, # type: ignore[arg-type] + q=tunable.quantization, # type: ignore[arg-type] distribution=distribution, # type: ignore[arg-type] - default=(float(tunable.default) - if tunable.in_range(tunable.default) and tunable.default is not None - else None), - meta=meta + default=( + float(tunable.default) + if tunable.in_range(tunable.default) and tunable.default is not None + else None + ), + meta=meta, ) else: raise TypeError(f"Invalid Parameter Type: {tunable.type}") @@ -135,31 +141,37 @@ def _tunable_to_configspace( # Create three hyperparameters: one for regular values, # one for special values, and one to choose between the two. 
(special_name, type_name) = special_param_names(tunable.name) - conf_space = ConfigurationSpace({ - tunable.name: range_hp, - special_name: CategoricalHyperparameter( - name=special_name, - choices=tunable.special, - weights=special_weights, - default_value=tunable.default if tunable.default in tunable.special else None, - meta=meta - ), - type_name: CategoricalHyperparameter( - name=type_name, - choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], - weights=switch_weights, - default_value=TunableValueKind.SPECIAL, - ), - }) - conf_space.add_condition(EqualsCondition( - conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL)) - conf_space.add_condition(EqualsCondition( - conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE)) + conf_space = ConfigurationSpace( + { + tunable.name: range_hp, + special_name: CategoricalHyperparameter( + name=special_name, + choices=tunable.special, + weights=special_weights, + default_value=tunable.default if tunable.default in tunable.special else None, + meta=meta, + ), + type_name: CategoricalHyperparameter( + name=type_name, + choices=[TunableValueKind.SPECIAL, TunableValueKind.RANGE], + weights=switch_weights, + default_value=TunableValueKind.SPECIAL, + ), + } + ) + conf_space.add_condition( + EqualsCondition(conf_space[special_name], conf_space[type_name], TunableValueKind.SPECIAL) + ) + conf_space.add_condition( + EqualsCondition(conf_space[tunable.name], conf_space[type_name], TunableValueKind.RANGE) + ) return conf_space -def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = None) -> ConfigurationSpace: +def tunable_groups_to_configspace( + tunables: TunableGroups, seed: Optional[int] = None +) -> ConfigurationSpace: """ Convert TunableGroups to hyperparameters in ConfigurationSpace. @@ -177,11 +189,14 @@ def tunable_groups_to_configspace(tunables: TunableGroups, seed: Optional[int] = A new ConfigurationSpace instance that corresponds to the input TunableGroups. """ space = ConfigurationSpace(seed=seed) - for (tunable, group) in tunables: + for tunable, group in tunables: space.add_configuration_space( - prefix="", delimiter="", + prefix="", + delimiter="", configuration_space=_tunable_to_configspace( - tunable, group.name, group.get_current_cost())) + tunable, group.name, group.get_current_cost() + ), + ) return space @@ -200,7 +215,7 @@ def tunable_values_to_configuration(tunables: TunableGroups) -> Configuration: A ConfigSpace Configuration. """ values: Dict[str, TunableValue] = {} - for (tunable, _group) in tunables: + for tunable, _group in tunables: if tunable.special: (special_name, type_name) = special_param_names(tunable.name) if tunable.value in tunable.special: @@ -222,10 +237,7 @@ def configspace_data_to_tunable_values(data: dict) -> Dict[str, TunableValue]: In particular, remove and keys suffixes added by `special_param_names`. 
""" data = data.copy() - specials = [ - special_param_name_strip(k) - for k in data.keys() if special_param_name_is_temp(k) - ] + specials = [special_param_name_strip(k) for k in data.keys() if special_param_name_is_temp(k)] for k in specials: (special_name, type_name) = special_param_names(k) if data[type_name] == TunableValueKind.SPECIAL: diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 9d90a58560..568cfff43f 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -24,11 +24,13 @@ class GridSearchOptimizer(TrackBestOptimizer): """Grid search optimizer.""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) # Track the grid as a set of tuples of tunable values and reconstruct the @@ -47,11 +49,19 @@ def __init__(self, def _sanity_check(self) -> None: size = np.prod([tunable.cardinality for (tunable, _group) in self._tunables]) if size == np.inf: - raise ValueError(f"Unquantized tunables are not supported for grid search: {self._tunables}") + raise ValueError( + f"Unquantized tunables are not supported for grid search: {self._tunables}" + ) if size > 10000: - _LOG.warning("Large number %d of config points requested for grid search: %s", size, self._tunables) + _LOG.warning( + "Large number %d of config points requested for grid search: %s", + size, + self._tunables, + ) if size > self._max_iter: - _LOG.warning("Grid search size %d, is greater than max iterations %d", size, self._max_iter) + _LOG.warning( + "Grid search size %d, is greater than max iterations %d", size, self._max_iter + ) def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]: """ @@ -64,12 +74,14 @@ def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], Non # names instead of the order given by TunableGroups. configs = [ configspace_data_to_tunable_values(dict(config)) - for config in - generate_grid(self.config_space, { - tunable.name: int(tunable.cardinality) - for (tunable, _group) in self._tunables - if tunable.quantization or tunable.type == "int" - }) + for config in generate_grid( + self.config_space, + { + tunable.name: int(tunable.cardinality) + for (tunable, _group) in self._tunables + if tunable.quantization or tunable.type == "int" + }, + ) ] names = set(tuple(configs.keys()) for configs in configs) assert len(names) == 1 @@ -99,15 +111,17 @@ def suggested_configs(self) -> Iterable[Dict[str, TunableValue]]: # See NOTEs above. 
return (dict(zip(self._config_keys, config)) for config in self._suggested_configs) - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for (params, score, trial_status) in zip(configs, scores, status): + for params, score, trial_status in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -146,20 +160,35 @@ def suggest(self) -> TunableGroups: _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) try: - config = dict(ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values())) + config = dict( + ConfigSpace.Configuration(self.config_space, values=tunables.get_param_values()) + ) self._suggested_configs.remove(tuple(config.values())) except KeyError: - _LOG.warning("Attempted to remove missing config (previously registered?) from suggested set: %s", tunables) + _LOG.warning( + ( + "Attempted to remove missing config " + "(previously registered?) from suggested set: %s" + ), + tunables, + ) return registered_score def not_converged(self) -> bool: if self._iter > self._max_iter: if bool(self._pending_configs): - _LOG.warning("Exceeded max iterations, but still have %d pending configs: %s", - len(self._pending_configs), list(self._pending_configs.keys())) + _LOG.warning( + "Exceeded max iterations, but still have %d pending configs: %s", + len(self._pending_configs), + list(self._pending_configs.keys()), + ) return False return bool(self._pending_configs) diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index e9a522a683..dfaa345548 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -36,35 +36,44 @@ class MlosCoreOptimizer(Optimizer): """A wrapper class for the mlos_core optimizers.""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) - opt_type = getattr(OptimizerType, self._config.pop( - 'optimizer_type', DEFAULT_OPTIMIZER_TYPE.name)) + opt_type = getattr( + OptimizerType, self._config.pop("optimizer_type", DEFAULT_OPTIMIZER_TYPE.name) + ) if opt_type == OptimizerType.SMAC: - output_directory = self._config.get('output_directory') + output_directory = self._config.get("output_directory") if output_directory is not None: # If output_directory is specified, turn it into an absolute path. 
- self._config['output_directory'] = os.path.abspath(output_directory) + self._config["output_directory"] = os.path.abspath(output_directory) else: - _LOG.warning("SMAC optimizer output_directory was null. SMAC will use a temporary directory.") + _LOG.warning( + ( + "SMAC optimizer output_directory was null. " + "SMAC will use a temporary directory." + ) + ) # Make sure max_trials >= max_iterations. - if 'max_trials' not in self._config: - self._config['max_trials'] = self._max_iter - assert int(self._config['max_trials']) >= self._max_iter, \ - f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" + if "max_trials" not in self._config: + self._config["max_trials"] = self._max_iter + assert ( + int(self._config["max_trials"]) >= self._max_iter + ), f"max_trials {self._config.get('max_trials')} <= max_iterations {self._max_iter}" - if 'run_name' not in self._config and self.experiment_id: - self._config['run_name'] = self.experiment_id + if "run_name" not in self._config and self.experiment_id: + self._config["run_name"] = self.experiment_id - space_adapter_type = self._config.pop('space_adapter_type', None) - space_adapter_config = self._config.pop('space_adapter_config', {}) + space_adapter_type = self._config.pop("space_adapter_type", None) + space_adapter_config = self._config.pop("space_adapter_config", {}) if space_adapter_type is not None: space_adapter_type = getattr(SpaceAdapterType, space_adapter_type) @@ -78,9 +87,12 @@ def __init__(self, space_adapter_kwargs=space_adapter_config, ) - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: self._opt.cleanup() return super().__exit__(ex_type, ex_val, ex_tb) @@ -88,10 +100,12 @@ def __exit__(self, ex_type: Optional[Type[BaseException]], def name(self) -> str: return f"{self.__class__.__name__}:{self._opt.__class__.__name__}" - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False @@ -99,7 +113,8 @@ def bulk_register(self, df_configs = self._to_df(configs) # Impute missing values, if necessary df_scores = self._adjust_signs_df( - pd.DataFrame([{} if score is None else score for score in scores])) + pd.DataFrame([{} if score is None else score for score in scores]) + ) opt_targets = list(self._opt_targets) if status is not None: @@ -124,7 +139,7 @@ def bulk_register(self, def _adjust_signs_df(self, df_scores: pd.DataFrame) -> pd.DataFrame: """In-place adjust the signs of the scores for MINIMIZATION problem.""" - for (opt_target, opt_dir) in self._opt_targets.items(): + for opt_target, opt_dir in self._opt_targets.items(): df_scores[opt_target] *= opt_dir return df_scores @@ -146,7 +161,7 @@ def _to_df(self, configs: Sequence[Dict[str, TunableValue]]) -> pd.DataFrame: df_configs = pd.DataFrame(configs) tunables_names = list(self._tunables.get_param_values().keys()) missing_cols = set(tunables_names).difference(df_configs.columns) - for (tunable, _group) in self._tunables: + for tunable, _group in self._tunables: if 
tunable.name in missing_cols: df_configs[tunable.name] = tunable.default else: @@ -178,22 +193,31 @@ def suggest(self) -> TunableGroups: df_config, _metadata = self._opt.suggest(defaults=self._start_with_defaults) self._start_with_defaults = False _LOG.info("Iteration %d :: Suggest:\n%s", self._iter, df_config) - return tunables.assign( - configspace_data_to_tunable_values(df_config.loc[0].to_dict())) - - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: - registered_score = super().register(tunables, status, score) # Sign-adjusted for MINIMIZATION + return tunables.assign(configspace_data_to_tunable_values(df_config.loc[0].to_dict())) + + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: + registered_score = super().register( + tunables, status, score + ) # Sign-adjusted for MINIMIZATION if status.is_completed(): assert registered_score is not None df_config = self._to_df([tunables.get_param_values()]) _LOG.debug("Score: %s Dataframe:\n%s", registered_score, df_config) # TODO: Specify (in the config) which metrics to pass to the optimizer. # Issue: https://github.com/microsoft/MLOS/issues/745 - self._opt.register(configs=df_config, scores=pd.DataFrame([registered_score], dtype=float)) + self._opt.register( + configs=df_config, scores=pd.DataFrame([registered_score], dtype=float) + ) return registered_score - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: (df_config, df_score, _df_context) = self._opt.get_best_observations() if len(df_config) == 0: return (None, None) diff --git a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py index 2d70512b1f..fd157db81a 100644 --- a/mlos_bench/mlos_bench/optimizers/mock_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mock_optimizer.py @@ -20,11 +20,13 @@ class MockOptimizer(TrackBestOptimizer): """Mock optimizer to test the Environment API.""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) rnd = random.Random(self.seed) self._random: Dict[str, Callable[[Tunable], TunableValue]] = { @@ -33,15 +35,17 @@ def __init__(self, "int": lambda tunable: rnd.randint(*tunable.range), } - def bulk_register(self, - configs: Sequence[dict], - scores: Sequence[Optional[Dict[str, TunableValue]]], - status: Optional[Sequence[Status]] = None) -> bool: + def bulk_register( + self, + configs: Sequence[dict], + scores: Sequence[Optional[Dict[str, TunableValue]]], + status: Optional[Sequence[Status]] = None, + ) -> bool: if not super().bulk_register(configs, scores, status): return False if status is None: status = [Status.SUCCEEDED] * len(configs) - for (params, score, trial_status) in zip(configs, scores, status): + for params, score, trial_status in zip(configs, scores, status): tunables = self._tunables.copy().assign(params) self.register(tunables, trial_status, score) if _LOG.isEnabledFor(logging.DEBUG): @@ -56,7 +60,7 @@ def suggest(self) -> TunableGroups: 
_LOG.info("Use default tunable values") self._start_with_defaults = False else: - for (tunable, _group) in tunables: + for tunable, _group in tunables: tunable.value = self._random[tunable.type](tunable) _LOG.info("Iteration %d :: Suggest: %s", self._iter, tunables) return tunables diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index d0c0e531ef..2f3f014943 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -23,11 +23,13 @@ class OneShotOptimizer(MockOptimizer): # TODO: Add support for multiple explicit configs (i.e., FewShot or Manual Optimizer) - #344 - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) _LOG.info("Run a single iteration for: %s", self._tunables) self._max_iter = 1 # Always run for just one iteration. diff --git a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py index e90f81a6ea..6ad8ab48d2 100644 --- a/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/track_best_optimizer.py @@ -20,17 +20,23 @@ class TrackBestOptimizer(Optimizer, metaclass=ABCMeta): """Base Optimizer class that keeps track of the best score and configuration.""" - def __init__(self, - tunables: TunableGroups, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + tunables: TunableGroups, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): super().__init__(tunables, config, global_config, service) self._best_config: Optional[TunableGroups] = None self._best_score: Optional[Dict[str, float]] = None - def register(self, tunables: TunableGroups, status: Status, - score: Optional[Dict[str, TunableValue]] = None) -> Optional[Dict[str, float]]: + def register( + self, + tunables: TunableGroups, + status: Status, + score: Optional[Dict[str, TunableValue]] = None, + ) -> Optional[Dict[str, float]]: registered_score = super().register(tunables, status, score) if status.is_succeeded() and self._is_better(registered_score): self._best_score = registered_score @@ -42,7 +48,7 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: if self._best_score is None: return True assert registered_score is not None - for (opt_target, best_score) in self._best_score.items(): + for opt_target, best_score in self._best_score.items(): score = registered_score[opt_target] if score < best_score: return True @@ -50,7 +56,9 @@ def _is_better(self, registered_score: Optional[Dict[str, float]]) -> bool: return False return False - def get_best_observation(self) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: + def get_best_observation( + self, + ) -> Union[Tuple[Dict[str, float], TunableGroups], Tuple[None, None]]: if self._best_score is None: return (None, None) score = self._get_scores(Status.SUCCEEDED, self._best_score) diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index 44c15eb709..348cf1ffa0 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py 
@@ -22,16 +22,19 @@ from typing_extensions import TypeAlias if sys.version_info >= (3, 9): - EnvironType: TypeAlias = os._Environ[str] # pylint: disable=protected-access,disable=unsubscriptable-object + EnvironType: TypeAlias = os._Environ[ + str + ] # pylint: disable=protected-access,disable=unsubscriptable-object else: - EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access + EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access # Handle case sensitivity differences between platforms. # https://stackoverflow.com/a/19023293 -if sys.platform == 'win32': +if sys.platform == "win32": import nt # type: ignore[import-not-found] # pylint: disable=import-error # (3.8) + environ: EnvironType = nt.environ else: environ: EnvironType = os.environ -__all__ = ['environ'] +__all__ = ["environ"] diff --git a/mlos_bench/mlos_bench/run.py b/mlos_bench/mlos_bench/run.py index 85c8c2b0c5..57c48a87b9 100755 --- a/mlos_bench/mlos_bench/run.py +++ b/mlos_bench/mlos_bench/run.py @@ -20,8 +20,9 @@ _LOG = logging.getLogger(__name__) -def _main(argv: Optional[List[str]] = None - ) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: +def _main( + argv: Optional[List[str]] = None, +) -> Tuple[Optional[Dict[str, float]], Optional[TunableGroups]]: launcher = Launcher("mlos_bench", "Systems autotuning and benchmarking tool", argv=argv) diff --git a/mlos_bench/mlos_bench/schedulers/__init__.py b/mlos_bench/mlos_bench/schedulers/__init__.py index a269560b73..381261e53d 100644 --- a/mlos_bench/mlos_bench/schedulers/__init__.py +++ b/mlos_bench/mlos_bench/schedulers/__init__.py @@ -8,6 +8,6 @@ from mlos_bench.schedulers.sync_scheduler import SyncScheduler __all__ = [ - 'Scheduler', - 'SyncScheduler', + "Scheduler", + "SyncScheduler", ] diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index b2a7328ebb..c268aab14c 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -27,13 +27,16 @@ class Scheduler(metaclass=ABCMeta): # pylint: disable=too-many-instance-attributes """Base class for the optimization loop scheduling policies.""" - def __init__(self, *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: Storage, - root_env_config: str): + def __init__( + self, + *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: Storage, + root_env_config: str, + ): """ Create a new instance of the scheduler. The constructor of this and the derived classes is called by the persistence service after reading the class JSON @@ -56,8 +59,9 @@ def __init__(self, *, Path to the root environment configuration. 
""" self.global_config = global_config - config = merge_parameters(dest=config.copy(), source=global_config, - required_keys=["experiment_id", "trial_id"]) + config = merge_parameters( + dest=config.copy(), source=global_config, required_keys=["experiment_id", "trial_id"] + ) self._experiment_id = config["experiment_id"].strip() self._trial_id = int(config["trial_id"]) @@ -67,7 +71,9 @@ def __init__(self, *, self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1)) if self._trial_config_repeat_count <= 0: - raise ValueError(f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}") + raise ValueError( + f"Invalid trial_config_repeat_count: {self._trial_config_repeat_count}" + ) self._do_teardown = bool(config.get("teardown", True)) @@ -91,7 +97,7 @@ def __repr__(self) -> str: """ return self.__class__.__name__ - def __enter__(self) -> 'Scheduler': + def __enter__(self) -> "Scheduler": """Enter the scheduler's context.""" _LOG.debug("Scheduler START :: %s", self) assert self.experiment is None @@ -111,10 +117,12 @@ def __enter__(self) -> 'Scheduler': ).__enter__() return self - def __exit__(self, - ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """Exit the context of the scheduler.""" if ex_val is None: _LOG.debug("Scheduler END :: %s", self) @@ -132,8 +140,12 @@ def __exit__(self, def start(self) -> None: """Start the optimization loop.""" assert self.experiment is not None - _LOG.info("START: Experiment: %s Env: %s Optimizer: %s", - self.experiment, self.environment, self.optimizer) + _LOG.info( + "START: Experiment: %s Env: %s Optimizer: %s", + self.experiment, + self.environment, + self.optimizer, + ) if _LOG.isEnabledFor(logging.INFO): _LOG.info("Root Environment:\n%s", self.environment.pprint()) @@ -191,27 +203,33 @@ def _schedule_new_optimizer_suggestions(self) -> bool: def schedule_trial(self, tunables: TunableGroups) -> None: """Add a configuration to the queue of trials.""" for repeat_i in range(1, self._trial_config_repeat_count + 1): - self._add_trial_to_queue(tunables, config={ - # Add some additional metadata to track for the trial such as the - # optimizer config used. - # Note: these values are unfortunately mutable at the moment. - # Consider them as hints of what the config was the trial *started*. - # It is possible that the experiment configs were changed - # between resuming the experiment (since that is not currently - # prevented). - "optimizer": self.optimizer.name, - "repeat_i": repeat_i, - "is_defaults": tunables.is_defaults, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(self.optimizer.targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - } - }) - - def _add_trial_to_queue(self, tunables: TunableGroups, - ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> None: + self._add_trial_to_queue( + tunables, + config={ + # Add some additional metadata to track for the trial such as the + # optimizer config used. + # Note: these values are unfortunately mutable at the moment. + # Consider them as hints of what the config was the trial *started*. + # It is possible that the experiment configs were changed + # between resuming the experiment (since that is not currently + # prevented). 
+ "optimizer": self.optimizer.name, + "repeat_i": repeat_i, + "is_defaults": tunables.is_defaults, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(self.optimizer.targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + }, + }, + ) + + def _add_trial_to_queue( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: """ Add a configuration to the queue of trials. diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index 0d3cfa0969..96cf15cdc9 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -49,7 +49,9 @@ def run_trial(self, trial: Storage.Trial) -> None: trial.update(Status.FAILED, datetime.now(UTC)) return - (status, timestamp, results) = self.environment.run() # Block and wait for the final result. + (status, timestamp, results) = ( + self.environment.run() + ) # Block and wait for the final result. _LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results) # In async mode (TODO), poll the environment for status and telemetry diff --git a/mlos_bench/mlos_bench/services/__init__.py b/mlos_bench/mlos_bench/services/__init__.py index b9b0b51693..b768afb09c 100644 --- a/mlos_bench/mlos_bench/services/__init__.py +++ b/mlos_bench/mlos_bench/services/__init__.py @@ -9,7 +9,7 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - 'Service', - 'FileShareService', - 'LocalExecService', + "Service", + "FileShareService", + "LocalExecService", ] diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index 07da98d11f..c941e0b132 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -17,10 +17,13 @@ class FileShareService(Service, SupportsFileShareOps, metaclass=ABCMeta): """An abstract base of all file shares.""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new file share with a given config. @@ -38,12 +41,16 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) @abstractmethod - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: """ Downloads contents from a remote share path to a local path. @@ -61,11 +68,18 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b if True (the default), download the entire directory tree. 
""" params = params or {} - _LOG.info("Download from File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", remote_path, local_path, params) + _LOG.info( + "Download from File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", + remote_path, + local_path, + params, + ) @abstractmethod - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: """ Uploads contents from a local path to remote share path. @@ -82,5 +96,10 @@ def upload(self, params: dict, local_path: str, remote_path: str, recursive: boo if True (the default), upload the entire directory tree. """ params = params or {} - _LOG.info("Upload to File Share %s recursively: %s -> %s (%s)", - "" if recursive else "non", local_path, remote_path, params) + _LOG.info( + "Upload to File Share %s recursively: %s -> %s (%s)", + "" if recursive else "non", + local_path, + remote_path, + params, + ) diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index 5b8a93fee6..c5d9b78c87 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -22,11 +22,13 @@ class Service: """An abstract base of all Environment Services and used to build up mix-ins.""" @classmethod - def new(cls, - class_name: str, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None) -> "Service": + def new( + cls, + class_name: str, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + ) -> "Service": """ Factory method for a new service with a given config. @@ -53,11 +55,13 @@ def new(cls, assert issubclass(cls, Service) return instantiate_from_config(cls, class_name, config, global_config, parent) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new service with a given config. @@ -97,8 +101,10 @@ def __init__(self, _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None) @staticmethod - def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None], - local_methods: Union[Dict[str, Callable], List[Callable]]) -> Dict[str, Callable]: + def merge_methods( + ext_methods: Union[Dict[str, Callable], List[Callable], None], + local_methods: Union[Dict[str, Callable], List[Callable]], + ) -> Dict[str, Callable]: """ Merge methods from the external caller with the local ones. @@ -135,9 +141,12 @@ def __enter__(self) -> "Service": self._in_context = True return self - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exit the Service mix-in context. 
@@ -174,9 +183,12 @@ def _enter_context(self) -> "Service": self._in_context = True return self - def _exit_context(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def _exit_context( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: """ Exits the context for this particular Service instance. @@ -194,7 +206,8 @@ def _validate_json_config(self, config: dict) -> None: mechanism. """ if self.__class__ == Service: - # Skip over the case where instantiate a bare base Service class in order to build up a mix-in. + # Skip over the case where instantiate a bare base Service class in + # order to build up a mix-in. assert config == {} return json_config: dict = { @@ -259,10 +272,11 @@ def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None # Unfortunately, by creating a set, we may destroy the ability to # preserve the context enter/exit order, but hopefully it doesn't # matter. - svc_method.__self__ for _, svc_method in self._service_methods.items() + svc_method.__self__ + for _, svc_method in self._service_methods.items() # Note: some methods are actually stand alone functions, so we need # to filter them out. - if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service) + if hasattr(svc_method, "__self__") and isinstance(svc_method.__self__, Service) } def export(self) -> Dict[str, Callable]: diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index 9532b8388b..2a90203fd1 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -59,11 +59,13 @@ class ConfigPersistenceService(Service, SupportsConfigLoading): BUILTIN_CONFIG_PATH = str(files("mlos_bench.config").joinpath("")).replace("\\", "/") - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of config persistence service. @@ -80,17 +82,22 @@ def __init__(self, New methods to register with the service. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.resolve_path, - self.load_config, - self.prepare_class_load, - self.build_service, - self.build_environment, - self.load_services, - self.load_environment, - self.load_environment_list, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.resolve_path, + self.load_config, + self.prepare_class_load, + self.build_service, + self.build_environment, + self.load_services, + self.load_environment, + self.load_environment_list, + ], + ), ) self._config_loader_service = self @@ -118,8 +125,7 @@ def config_paths(self) -> List[str]: """ return list(self._config_path) # make a copy to avoid modifications - def resolve_path(self, file_path: str, - extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -149,10 +155,11 @@ def resolve_path(self, file_path: str, _LOG.debug("Path not resolved: %s", file_path) return file_path - def load_config(self, - json_file_name: str, - schema_type: Optional[ConfigSchema], - ) -> Dict[str, Any]: + def load_config( + self, + json_file_name: str, + schema_type: Optional[ConfigSchema], + ) -> Dict[str, Any]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. This method is exported to be used as a service. @@ -171,16 +178,22 @@ def load_config(self, """ json_file_name = self.resolve_path(json_file_name) _LOG.info("Load config: %s", json_file_name) - with open(json_file_name, mode='r', encoding='utf-8') as fh_json: + with open(json_file_name, mode="r", encoding="utf-8") as fh_json: config = json5.load(fh_json) if schema_type is not None: try: schema_type.validate(config) except (ValidationError, SchemaError) as ex: - _LOG.error("Failed to validate config %s against schema type %s at %s", - json_file_name, schema_type.name, schema_type.value) - raise ValueError(f"Failed to validate config {json_file_name} against " + - f"schema type {schema_type.name} at {schema_type.value}") from ex + _LOG.error( + "Failed to validate config %s against schema type %s at %s", + json_file_name, + schema_type.name, + schema_type.value, + ) + raise ValueError( + f"Failed to validate config {json_file_name} against " + + f"schema type {schema_type.name} at {schema_type.value}" + ) from ex if isinstance(config, dict) and config.get("$schema"): # Remove $schema attributes from the config after we've validated # them to avoid passing them on to other objects @@ -191,11 +204,14 @@ def load_config(self, del config["$schema"] else: _LOG.warning("Config %s is not validated against a schema.", json_file_name) - return config # type: ignore[no-any-return] + return config # type: ignore[no-any-return] - def prepare_class_load(self, config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None) -> Tuple[str, Dict[str, Any]]: + def prepare_class_load( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + ) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. Mix-in the global parameters and resolve the local file system paths, where it is required. 
@@ -237,16 +253,22 @@ def prepare_class_load(self, config: Dict[str, Any], raise ValueError(f"Parameter {key} must be a string or a list") if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Instantiating: %s with config:\n%s", - class_name, json.dumps(class_config, indent=2)) + _LOG.debug( + "Instantiating: %s with config:\n%s", + class_name, + json.dumps(class_config, indent=2), + ) return (class_name, class_config) - def build_optimizer(self, *, - tunables: TunableGroups, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None) -> Optimizer: + def build_optimizer( + self, + *, + tunables: TunableGroups, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + ) -> Optimizer: """ Instantiation of mlos_bench Optimizer that depend on Service and TunableGroups. @@ -274,18 +296,24 @@ def build_optimizer(self, *, if tunables_path is not None: tunables = self._load_tunables(tunables_path, tunables) (class_name, class_config) = self.prepare_class_load(config, global_config) - inst = instantiate_from_config(Optimizer, class_name, # type: ignore[type-abstract] - tunables=tunables, - config=class_config, - global_config=global_config, - service=service) + inst = instantiate_from_config( + Optimizer, # type: ignore[type-abstract] + class_name, + tunables=tunables, + config=class_config, + global_config=global_config, + service=service, + ) _LOG.info("Created: Optimizer %s", inst) return inst - def build_storage(self, *, - service: Service, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None) -> "Storage": + def build_storage( + self, + *, + service: Service, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + ) -> "Storage": """ Instantiation of mlos_bench Storage objects. @@ -304,23 +332,31 @@ def build_storage(self, *, A new instance of the Storage class. """ (class_name, class_config) = self.prepare_class_load(config, global_config) + # pylint: disable=import-outside-toplevel from mlos_bench.storage.base_storage import ( - Storage, # pylint: disable=import-outside-toplevel + Storage, + ) + + inst = instantiate_from_config( + Storage, # type: ignore[type-abstract] + class_name, + config=class_config, + global_config=global_config, + service=service, ) - inst = instantiate_from_config(Storage, class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - service=service) _LOG.info("Created: Storage %s", inst) return inst - def build_scheduler(self, *, - config: Dict[str, Any], - global_config: Dict[str, Any], - environment: Environment, - optimizer: Optimizer, - storage: "Storage", - root_env_config: str) -> "Scheduler": + def build_scheduler( + self, + *, + config: Dict[str, Any], + global_config: Dict[str, Any], + environment: Environment, + optimizer: Optimizer, + storage: "Storage", + root_env_config: str, + ) -> "Scheduler": """ Instantiation of mlos_bench Scheduler. @@ -345,25 +381,32 @@ def build_scheduler(self, *, A new instance of the Scheduler. 
""" (class_name, class_config) = self.prepare_class_load(config, global_config) + # pylint: disable=import-outside-toplevel from mlos_bench.schedulers.base_scheduler import ( - Scheduler, # pylint: disable=import-outside-toplevel + Scheduler, + ) + + inst = instantiate_from_config( + Scheduler, # type: ignore[type-abstract] + class_name, + config=class_config, + global_config=global_config, + environment=environment, + optimizer=optimizer, + storage=storage, + root_env_config=root_env_config, ) - inst = instantiate_from_config(Scheduler, class_name, # type: ignore[type-abstract] - config=class_config, - global_config=global_config, - environment=environment, - optimizer=optimizer, - storage=storage, - root_env_config=root_env_config) _LOG.info("Created: Scheduler %s", inst) return inst - def build_environment(self, # pylint: disable=too-many-arguments - config: Dict[str, Any], - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> Environment: + def build_environment( + self, # pylint: disable=too-many-arguments + config: Dict[str, Any], + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> Environment: """ Factory method for a new environment with a given config. @@ -403,16 +446,24 @@ def build_environment(self, # pylint: disable=too-many-arguments tunables = self._load_tunables(env_tunables_path, tunables) _LOG.debug("Creating env: %s :: %s", env_name, env_class) - env = Environment.new(env_name=env_name, class_name=env_class, - config=env_config, global_config=global_config, - tunables=tunables, service=service) + env = Environment.new( + env_name=env_name, + class_name=env_class, + config=env_config, + global_config=global_config, + tunables=tunables, + service=service, + ) _LOG.info("Created env: %s :: %s", env_name, env) return env - def _build_standalone_service(self, config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def _build_standalone_service( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -437,9 +488,12 @@ def _build_standalone_service(self, config: Dict[str, Any], _LOG.info("Created service: %s", service) return service - def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def _build_composite_service( + self, + config_list: Iterable[Dict[str, Any]], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. 
@@ -465,18 +519,21 @@ def _build_composite_service(self, config_list: Iterable[Dict[str, Any]], service.register(parent.export()) for config in config_list: - service.register(self._build_standalone_service( - config, global_config, service).export()) + service.register( + self._build_standalone_service(config, global_config, service).export() + ) if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Created mix-in service: %s", service) return service - def build_service(self, - config: Dict[str, Any], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def build_service( + self, + config: Dict[str, Any], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Factory method for a new service with a given config. @@ -498,8 +555,7 @@ def build_service(self, services from the list plus the parent mix-in. """ if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Build service from config:\n%s", - json.dumps(config, indent=2)) + _LOG.debug("Build service from config:\n%s", json.dumps(config, indent=2)) assert isinstance(config, dict) config_list: List[Dict[str, Any]] @@ -514,12 +570,14 @@ def build_service(self, return self._build_composite_service(config_list, global_config, parent) - def load_environment(self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> Environment: + def load_environment( + self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> Environment: """ Load and build new environment from the config file. @@ -546,12 +604,14 @@ def load_environment(self, # pylint: disable=too-many-arguments assert isinstance(config, dict) return self.build_environment(config, tunables, global_config, parent_args, service) - def load_environment_list(self, # pylint: disable=too-many-arguments - json_file_name: str, - tunables: TunableGroups, - global_config: Optional[Dict[str, Any]] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional[Service] = None) -> List[Environment]: + def load_environment_list( + self, # pylint: disable=too-many-arguments + json_file_name: str, + tunables: TunableGroups, + global_config: Optional[Dict[str, Any]] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional[Service] = None, + ) -> List[Environment]: """ Load and build a list of environments from the config file. @@ -576,13 +636,14 @@ def load_environment_list(self, # pylint: disable=too-many-arguments A list of new benchmarking environments. 
""" config = self.load_config(json_file_name, ConfigSchema.ENVIRONMENT) - return [ - self.build_environment(config, tunables, global_config, parent_args, service) - ] + return [self.build_environment(config, tunables, global_config, parent_args, service)] - def load_services(self, json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None) -> Service: + def load_services( + self, + json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + ) -> Service: """ Read the configuration files and bundle all service methods from those configs into a single Service object. @@ -601,16 +662,16 @@ def load_services(self, json_file_names: Iterable[str], service : Service A collection of service methods. """ - _LOG.info("Load services: %s parent: %s", - json_file_names, parent.__class__.__name__) + _LOG.info("Load services: %s parent: %s", json_file_names, parent.__class__.__name__) service = Service({}, global_config, parent) for fname in json_file_names: config = self.load_config(fname, ConfigSchema.SERVICE) service.register(self.build_service(config, global_config, service).export()) return service - def _load_tunables(self, json_file_names: Iterable[str], - parent: TunableGroups) -> TunableGroups: + def _load_tunables( + self, json_file_names: Iterable[str], parent: TunableGroups + ) -> TunableGroups: """ Load a collection of tunable parameters from JSON files into the parent TunableGroup. diff --git a/mlos_bench/mlos_bench/services/local/__init__.py b/mlos_bench/mlos_bench/services/local/__init__.py index bf1361024a..afe9f05d20 100644 --- a/mlos_bench/mlos_bench/services/local/__init__.py +++ b/mlos_bench/mlos_bench/services/local/__init__.py @@ -7,5 +7,5 @@ from mlos_bench.services.local.local_exec import LocalExecService __all__ = [ - 'LocalExecService', + "LocalExecService", ] diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index 189a54b210..f595c75a89 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -79,11 +79,13 @@ class LocalExecService(TempDirContextService, SupportsLocalExec): vs the target environment. """ - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of a service to run scripts locally. @@ -100,14 +102,16 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.local_exec]) + config, global_config, parent, self.merge_methods(methods, [self.local_exec]) ) self.abort_on_error = self.config.get("abort_on_error", True) - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. 
@@ -175,9 +179,9 @@ def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: subcmd_tokens.insert(0, sys.executable) return subcmd_tokens - def _local_exec_script(self, script_line: str, - env_params: Optional[Mapping[str, "TunableValue"]], - cwd: str) -> Tuple[int, str, str]: + def _local_exec_script( + self, script_line: str, env_params: Optional[Mapping[str, "TunableValue"]], cwd: str + ) -> Tuple[int, str, str]: """ Execute the script from `script_path` in a local process. @@ -206,7 +210,7 @@ def _local_exec_script(self, script_line: str, if env_params: env = {key: str(val) for (key, val) in env_params.items()} - if sys.platform == 'win32': + if sys.platform == "win32": # A hack to run Python on Windows with env variables set: env_copy = environ.copy() env_copy["PYTHONPATH"] = "" @@ -214,7 +218,7 @@ def _local_exec_script(self, script_line: str, env = env_copy try: - if sys.platform != 'win32': + if sys.platform != "win32": cmd = [" ".join(cmd)] _LOG.info("Run: %s", cmd) @@ -222,8 +226,15 @@ def _local_exec_script(self, script_line: str, _LOG.debug("Expands to: %s", Template(" ".join(cmd)).safe_substitute(env)) _LOG.debug("Current working dir: %s", cwd) - proc = subprocess.run(cmd, env=env or None, cwd=cwd, shell=True, - text=True, check=False, capture_output=True) + proc = subprocess.run( + cmd, + env=env or None, + cwd=cwd, + shell=True, + text=True, + check=False, + capture_output=True, + ) _LOG.debug("Run: return code = %d", proc.returncode) return (proc.returncode, proc.stdout, proc.stderr) diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py index 4221754cb0..06bb32bc5f 100644 --- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py +++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py @@ -26,11 +26,13 @@ class TempDirContextService(Service, metaclass=abc.ABCMeta): supposed to be used as a standalone service. """ - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of a service that provides temporary directory context for local exec service. @@ -48,8 +50,7 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.temp_dir_context]) + config, global_config, parent, self.merge_methods(methods, [self.temp_dir_context]) ) self._temp_dir = self.config.get("temp_dir") if self._temp_dir: @@ -59,7 +60,9 @@ def __init__(self, self._temp_dir = self._config_loader_service.resolve_path(self._temp_dir) _LOG.info("%s: temp dir: %s", self, self._temp_dir) - def temp_dir_context(self, path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]: + def temp_dir_context( + self, path: Optional[str] = None + ) -> Union[TemporaryDirectory, nullcontext]: """ Create a temp directory or use the provided path. 
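The temp_dir_context() signature reformatted above encodes a small but useful pattern: callers always open the result in a with block, and the service either creates a fresh TemporaryDirectory (removed on exit) or wraps an explicitly configured path in a nullcontext so nothing gets deleted. A standalone sketch of that idea, with names chosen for illustration rather than taken from the Service API:

from contextlib import nullcontext
from tempfile import TemporaryDirectory
from typing import Optional, Union


def temp_dir_context(path: Optional[str] = None) -> Union[TemporaryDirectory, nullcontext]:
    """Create a temp directory, or reuse `path` without ever cleaning it up."""
    if path is not None:
        return nullcontext(path)
    return TemporaryDirectory()


with temp_dir_context() as work_dir:
    print("scratch space:", work_dir)  # removed when the block exits

with temp_dir_context("/tmp") as work_dir:
    print("reusing:", work_dir)  # left untouched when the block exits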
diff --git a/mlos_bench/mlos_bench/services/remote/azure/__init__.py b/mlos_bench/mlos_bench/services/remote/azure/__init__.py index 0a148250c3..cfe12e3c46 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/__init__.py +++ b/mlos_bench/mlos_bench/services/remote/azure/__init__.py @@ -11,9 +11,9 @@ from mlos_bench.services.remote.azure.azure_vm_services import AzureVMService __all__ = [ - 'AzureAuthService', - 'AzureFileShareService', - 'AzureNetworkService', - 'AzureSaaSConfigService', - 'AzureVMService', + "AzureAuthService", + "AzureFileShareService", + "AzureNetworkService", + "AzureSaaSConfigService", + "AzureVMService", ] diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index 350ecd6e5f..bded5fb99e 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -23,13 +23,15 @@ class AzureAuthService(Service, SupportsAuth): """Helper methods to get access to Azure services.""" - _REQ_INTERVAL = 300 # = 5 min - - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + _REQ_INTERVAL = 300 # = 5 min + + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure authentication services proxy. @@ -46,11 +48,16 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.get_access_token, - self.get_auth_headers, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.get_access_token, + self.get_auth_headers, + ], + ), ) # This parameter can come from command line as strings, so conversion is needed. @@ -66,12 +73,13 @@ def __init__(self, # Verify info required for SP auth early if "spClientId" in self.config: check_required_params( - self.config, { + self.config, + { "spClientId", "keyVaultName", "certName", "tenant", - } + }, ) def _init_sp(self) -> None: @@ -100,7 +108,9 @@ def _init_sp(self) -> None: cert_bytes = b64decode(secret.value) # Reauthenticate as the service principal. 
- self._cred = azure_id.CertificateCredential(tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes) + self._cred = azure_id.CertificateCredential( + tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes + ) def get_access_token(self) -> str: """Get the access token from Azure CLI, if expired.""" diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index dc2c049c1e..24c4242e8f 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -25,29 +25,32 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): """Helper methods to manage and deploy Azure resources via REST APIs.""" - _POLL_INTERVAL = 4 # seconds - _POLL_TIMEOUT = 300 # seconds - _REQUEST_TIMEOUT = 5 # seconds + _POLL_INTERVAL = 4 # seconds + _POLL_TIMEOUT = 300 # seconds + _REQUEST_TIMEOUT = 5 # seconds _REQUEST_TOTAL_RETRIES = 10 # Total number retries for each request - _REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries})) + # Delay (seconds) between retries: {backoff factor} * (2 ** ({number of previous retries})) + _REQUEST_RETRY_BACKOFF_FACTOR = 0.3 # Azure Resources Deployment REST API as described in # https://docs.microsoft.com/en-us/rest/api/resources/deployments _URL_DEPLOY = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Resources" + - "/deployments/{deployment_name}" + - "?api-version=2022-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Resources" + + "/deployments/{deployment_name}" + + "?api-version=2022-05-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of an Azure Services proxy. @@ -65,32 +68,44 @@ def __init__(self, """ super().__init__(config, global_config, parent, methods) - check_required_params(self.config, [ - "subscription", - "resourceGroup", - ]) + check_required_params( + self.config, + [ + "subscription", + "resourceGroup", + ], + ) # These parameters can come from command line as strings, so conversion is needed. 
self._poll_interval = float(self.config.get("pollInterval", self._POLL_INTERVAL)) self._poll_timeout = float(self.config.get("pollTimeout", self._POLL_TIMEOUT)) self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - self._total_retries = int(self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES)) - self._backoff_factor = float(self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR)) + self._total_retries = int( + self.config.get("requestTotalRetries", self._REQUEST_TOTAL_RETRIES) + ) + self._backoff_factor = float( + self.config.get("requestBackoffFactor", self._REQUEST_RETRY_BACKOFF_FACTOR) + ) self._deploy_template = {} self._deploy_params = {} if self.config.get("deploymentTemplatePath") is not None: # TODO: Provide external schema validation? template = self.config_loader_service.load_config( - self.config['deploymentTemplatePath'], schema_type=None) + self.config["deploymentTemplatePath"], schema_type=None + ) assert template is not None and isinstance(template, dict) self._deploy_template = template # Allow for recursive variable expansion as we do with global params and const_args. - deploy_params = DictTemplater(self.config['deploymentTemplateParameters']).expand_vars(extra_source_dict=global_config) + deploy_params = DictTemplater(self.config["deploymentTemplateParameters"]).expand_vars( + extra_source_dict=global_config + ) self._deploy_params = merge_parameters(dest=deploy_params, source=global_config) else: - _LOG.info("No deploymentTemplatePath provided. Deployment services will be unavailable.") + _LOG.info( + "No deploymentTemplatePath provided. Deployment services will be unavailable." + ) @property def deploy_params(self) -> dict: @@ -123,14 +138,16 @@ def _get_session(self, params: dict) -> requests.Session: session = requests.Session() session.mount( "https://", - HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor))) + HTTPAdapter(max_retries=Retry(total=total_retries, backoff_factor=backoff_factor)), + ) session.headers.update(self._get_headers()) return session def _get_headers(self) -> dict: """Get the headers for the REST API calls.""" - assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ - "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() @staticmethod @@ -226,9 +243,11 @@ def _check_operation_status(self, params: dict) -> Tuple[Status, dict]: return (Status.FAILED, {}) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) if response.status_code == 200: output = response.json() @@ -251,7 +270,8 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for resource being deployed; otherwise, wait for successful deprovisioning. + If True, wait for resource being deployed; otherwise, wait for + successful deprovisioning. 
Returns ------- @@ -261,12 +281,16 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ params = self._set_default_params(params) - _LOG.info("Wait for %s to %s", params.get("deploymentName"), - "provision" if is_setup else "deprovision") + _LOG.info( + "Wait for %s to %s", + params.get("deploymentName"), + "provision" if is_setup else "deprovision", + ) return self._wait_while(self._check_deployment, Status.PENDING, params) - def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], - loop_status: Status, params: dict) -> Tuple[Status, dict]: + def _wait_while( + self, func: Callable[[dict], Tuple[Status, dict]], loop_status: Status, params: dict + ) -> Tuple[Status, dict]: """ Invoke `func` periodically while the status is equal to `loop_status`. Return TIMED_OUT when timing out. @@ -288,12 +312,18 @@ def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], """ params = self._set_default_params(params) config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"]) + dest=self.config.copy(), source=params, required_keys=["deploymentName"] + ) poll_period = params.get("pollInterval", self._poll_interval) - _LOG.debug("Wait for %s status %s :: poll %.2f timeout %d s", - config["deploymentName"], loop_status, poll_period, self._poll_timeout) + _LOG.debug( + "Wait for %s status %s :: poll %.2f timeout %d s", + config["deploymentName"], + loop_status, + poll_period, + self._poll_timeout, + ) ts_timeout = time.time() + self._poll_timeout poll_delay = poll_period @@ -317,7 +347,9 @@ def _wait_while(self, func: Callable[[dict], Tuple[Status, dict]], _LOG.warning("Request timed out: %s", params) return (Status.TIMED_OUT, {}) - def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements + def _check_deployment( + self, params: dict + ) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements """ Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. 
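The _wait_while() helper reformatted above is the core polling loop for all of these deployment operations: call a status-check function while it keeps returning the "still in progress" status, and give up with TIMED_OUT once the poll timeout elapses. A self-contained sketch of that loop, using a stand-in Status enum and a fake check function rather than the mlos_bench types:

import time
from enum import Enum
from typing import Callable, Dict, Tuple


class Status(Enum):
    PENDING = "PENDING"
    SUCCEEDED = "SUCCEEDED"
    TIMED_OUT = "TIMED_OUT"


def wait_while(
    func: Callable[[Dict], Tuple[Status, Dict]],
    loop_status: Status,
    params: Dict,
    poll_interval: float = 4.0,
    timeout: float = 300.0,
) -> Tuple[Status, Dict]:
    """Poll `func` while it returns `loop_status`; stop on a status change or timeout."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        (status, output) = func(params)
        if status != loop_status:
            return (status, output)
        time.sleep(poll_interval)
    return (Status.TIMED_OUT, {})


calls = {"n": 0}


def fake_check(_params: Dict) -> Tuple[Status, Dict]:
    """Pretend deployment check that succeeds on the third poll."""
    calls["n"] += 1
    if calls["n"] >= 3:
        return (Status.SUCCEEDED, {"polls": calls["n"]})
    return (Status.PENDING, {})


print(wait_while(fake_check, Status.PENDING, {}, poll_interval=0.01, timeout=1.0))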
@@ -342,7 +374,7 @@ def _check_deployment(self, params: dict) -> Tuple[Status, dict]: # pylint: di "subscription", "resourceGroup", "deploymentName", - ] + ], ) _LOG.info("Check deployment: %s", config["deploymentName"]) @@ -403,13 +435,18 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: if not self._deploy_template: raise ValueError(f"Missing deployment template: {self}") params = self._set_default_params(params) - config = merge_parameters(dest=self.config.copy(), source=params, required_keys=["deploymentName"]) + config = merge_parameters( + dest=self.config.copy(), source=params, required_keys=["deploymentName"] + ) _LOG.info("Deploy: %s :: %s", config["deploymentName"], params) params = merge_parameters(dest=self._deploy_params.copy(), source=params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Deploy: %s merged params ::\n%s", - config["deploymentName"], json.dumps(params, indent=2)) + _LOG.debug( + "Deploy: %s merged params ::\n%s", + config["deploymentName"], + json.dumps(params, indent=2), + ) url = self._URL_DEPLOY.format( subscription=config["subscription"], @@ -422,22 +459,26 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: "mode": "Incremental", "template": self._deploy_template, "parameters": { - key: {"value": val} for (key, val) in params.items() + key: {"value": val} + for (key, val) in params.items() if key in self._deploy_template.get("parameters", {}) - } + }, } } if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2)) - response = requests.put(url, json=json_req, - headers=self._get_headers(), timeout=self._request_timeout) + response = requests.put( + url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout + ) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) else: _LOG.info("Response: %s", response) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 653963922d..ddd41afcc2 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -23,11 +23,13 @@ class AzureFileShareService(FileShareService): _SHARE_URL = "https://{account_name}.file.core.windows.net/{fs_name}" - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new file share Service for Azure environments with a given config. @@ -45,16 +47,19 @@ def __init__(self, New methods to register with the service. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) check_required_params( - self.config, { + self.config, + { "storageAccountName", "storageFileShareName", "storageAccountKey", - } + }, ) self._share_client = ShareClient.from_share_url( @@ -65,7 +70,9 @@ def __init__(self, credential=self.config["storageAccountKey"], ) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: super().download(params, remote_path, local_path, recursive) dir_client = self._share_client.get_directory_client(remote_path) if dir_client.exists(): @@ -90,7 +97,9 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b # Translate into non-Azure exception: raise FileNotFoundError(f"Cannot download: {remote_path}") from ex - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: super().upload(params, local_path, remote_path, recursive) self._upload(local_path, remote_path, recursive, set()) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index ff6eb160fd..9c66fc7b0c 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -26,22 +26,24 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # Azure Compute REST API calls as described in # https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 - # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 + # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 # pylint: disable=line-too-long # noqa _URL_DEPROVISION = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Network" + - "/virtualNetwork/{vnet_name}" + - "/delete" + - "?api-version=2023-05-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Network" + + "/virtualNetwork/{vnet_name}" + + "/delete" + + "?api-version=2023-05-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure Network services proxy. @@ -58,25 +60,34 @@ def __init__(self, New methods to register with the service. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - # SupportsNetworkProvisioning - self.provision_network, - self.deprovision_network, - self.wait_network_deployment, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + # SupportsNetworkProvisioning + self.provision_network, + self.deprovision_network, + self.wait_network_deployment, + ], + ), ) if not self._deploy_template: - raise ValueError("AzureNetworkService requires a deployment template:\n" - + f"config={config}\nglobal_config={global_config}") + raise ValueError( + "AzureNetworkService requires a deployment template:\n" + + f"config={config}\nglobal_config={global_config}" + ) - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vnetName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vnetName']}-deployment" - _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) + _LOG.info( + "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] + ) return params def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: @@ -147,15 +158,18 @@ def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple "resourceGroup", "deploymentName", "vnetName", - ] + ], ) _LOG.info("Deprovision Network: %s", config["vnetName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) - (status, results) = self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vnet_name=config["vnetName"], - )) + (status, results) = self._azure_rest_api_post_helper( + config, + self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vnet_name=config["vnetName"], + ), + ) if ignore_errors and status == Status.FAILED: _LOG.warning("Ignoring error: %s", results) status = Status.SUCCEEDED diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index b78a069c62..9a2081c90f 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -28,20 +28,22 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig): # https://learn.microsoft.com/en-us/rest/api/mariadb/configurations _URL_CONFIGURE = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/{provider}" + - "/{server_type}/{vm_name}" + - "/{update}" + - "?api-version={api_version}" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/{provider}" + + "/{server_type}/{vm_name}" + + "/{update}" + + "?api-version={api_version}" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + 
parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure services proxy. @@ -58,18 +60,20 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.configure, - self.is_config_pending - ]) + config, + global_config, + parent, + self.merge_methods(methods, [self.configure, self.is_config_pending]), ) - check_required_params(self.config, { - "subscription", - "resourceGroup", - "provider", - }) + check_required_params( + self.config, + { + "subscription", + "resourceGroup", + "provider", + }, + ) # Provide sane defaults for known DB providers. provider = self.config.get("provider") @@ -113,8 +117,7 @@ def __init__(self, # These parameters can come from command line as strings, so conversion is needed. self._request_timeout = float(self.config.get("requestTimeout", self._REQUEST_TIMEOUT)) - def configure(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service. @@ -152,31 +155,36 @@ def is_config_pending(self, config: Dict[str, Any]) -> Tuple[Status, dict]: If "isConfigPendingReboot" is set to True, rebooting a VM is necessary. Status is one of {PENDING, TIMED_OUT, SUCCEEDED, FAILED} """ - config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_get.format(vm_name=config["vmName"]) _LOG.debug("Request: GET %s", url) - response = requests.put( - url, headers=self._get_headers(), timeout=self._request_timeout) + response = requests.put(url, headers=self._get_headers(), timeout=self._request_timeout) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) if response.status_code != 200: return (Status.FAILED, {}) # Currently, Azure Flex servers require a VM reboot. - return (Status.SUCCEEDED, {"isConfigPendingReboot": any( - {'False': False, 'True': True}[val['properties']['isConfigPendingRestart']] - for val in response.json()['value'] - )}) + return ( + Status.SUCCEEDED, + { + "isConfigPendingReboot": any( + {"False": False, "True": True}[val["properties"]["isConfigPendingRestart"]] + for val in response.json()["value"] + ) + }, + ) def _get_headers(self) -> dict: """Get the headers for the REST API calls.""" - assert self._parent is not None and isinstance(self._parent, SupportsAuth), \ - "Authorization service not provided. Include service-auth.jsonc?" + assert self._parent is not None and isinstance( + self._parent, SupportsAuth + ), "Authorization service not provided. Include service-auth.jsonc?" return self._parent.get_auth_headers() - def _config_one(self, config: Dict[str, Any], - param_name: str, param_value: Any) -> Tuple[Status, dict]: + def _config_one( + self, config: Dict[str, Any], param_name: str, param_value: Any + ) -> Tuple[Status, dict]: """ Update a single parameter of the Azure DB service. @@ -195,13 +203,15 @@ def _config_one(self, config: Dict[str, Any], A pair of Status and result. The result is always {}. 
Status is one of {PENDING, SUCCEEDED, FAILED} """ - config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_set.format(vm_name=config["vmName"], param_name=param_name) _LOG.debug("Request: PUT %s", url) - response = requests.put(url, headers=self._get_headers(), - json={"properties": {"value": str(param_value)}}, - timeout=self._request_timeout) + response = requests.put( + url, + headers=self._get_headers(), + json={"properties": {"value": str(param_value)}}, + timeout=self._request_timeout, + ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) @@ -209,8 +219,7 @@ def _config_one(self, config: Dict[str, Any], return (Status.SUCCEEDED, {}) return (Status.FAILED, {}) - def _config_many(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_many(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: """ Update the parameters of an Azure DB service one-by-one. (If batch API is not available for it). @@ -228,14 +237,13 @@ def _config_many(self, config: Dict[str, Any], A pair of Status and result. The result is always {}. Status is one of {PENDING, SUCCEEDED, FAILED} """ - for (param_name, param_value) in params.items(): + for param_name, param_value in params.items(): (status, result) = self._config_one(config, param_name, param_value) if not status.is_succeeded(): return (status, result) return (Status.SUCCEEDED, {}) - def _config_batch(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple[Status, dict]: + def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple[Status, dict]: """ Batch update the parameters of an Azure DB service. @@ -252,19 +260,18 @@ def _config_batch(self, config: Dict[str, Any], A pair of Status and result. The result is always {}. 
Status is one of {PENDING, SUCCEEDED, FAILED} """ - config = merge_parameters( - dest=self.config.copy(), source=config, required_keys=["vmName"]) + config = merge_parameters(dest=self.config.copy(), source=config, required_keys=["vmName"]) url = self._url_config_set.format(vm_name=config["vmName"]) json_req = { "value": [ - {"name": key, "properties": {"value": str(val)}} - for (key, val) in params.items() + {"name": key, "properties": {"value": str(val)}} for (key, val) in params.items() ], # "resetAllToDefault": "True" } _LOG.debug("Request: POST %s", url) - response = requests.post(url, headers=self._get_headers(), - json=json_req, timeout=self._request_timeout) + response = requests.post( + url, headers=self._get_headers(), json=json_req, timeout=self._request_timeout + ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: return (Status.TIMED_OUT, {}) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index 384618415d..06e71780e8 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -24,7 +24,13 @@ _LOG = logging.getLogger(__name__) -class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps, SupportsRemoteExec): +class AzureVMService( + AzureDeploymentService, + SupportsHostProvisioning, + SupportsHostOps, + SupportsOSOps, + SupportsRemoteExec, +): """Helper methods to manage VMs on Azure.""" # pylint: disable=too-many-ancestors @@ -34,35 +40,35 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start _URL_START = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/start" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/start" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off _URL_STOP = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/powerOff" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/powerOff" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate _URL_DEALLOCATE = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/deallocate" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/deallocate" + + "?api-version=2022-03-01" ) # TODO: This is probably the more correct URL to use for the deprovision operation. 
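The URL constants rewritten in these hunks keep their explicit string concatenation, but the + operator moves from the end of each line to the start of the next, matching PEP 8's preference for breaking before binary operators. A small illustration of the resulting template and how such a template is typically filled in (the values below are placeholders, not a live API call):

_URL_START_EXAMPLE = (
    "https://management.azure.com"
    + "/subscriptions/{subscription}"
    + "/resourceGroups/{resource_group}"
    + "/providers/Microsoft.Compute"
    + "/virtualMachines/{vm_name}"
    + "/start"
    + "?api-version=2022-03-01"
)

url = _URL_START_EXAMPLE.format(
    subscription="00000000-0000-0000-0000-000000000000",
    resource_group="my-resource-group",
    vm_name="my-vm",
)
print(url)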
@@ -84,31 +90,33 @@ class AzureVMService(AzureDeploymentService, SupportsHostProvisioning, SupportsH # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart _URL_REBOOT = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/restart" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/restart" + + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command _URL_REXEC_RUN = ( - "https://management.azure.com" + - "/subscriptions/{subscription}" + - "/resourceGroups/{resource_group}" + - "/providers/Microsoft.Compute" + - "/virtualMachines/{vm_name}" + - "/runCommand" + - "?api-version=2022-03-01" + "https://management.azure.com" + + "/subscriptions/{subscription}" + + "/resourceGroups/{resource_group}" + + "/providers/Microsoft.Compute" + + "/virtualMachines/{vm_name}" + + "/runCommand" + + "?api-version=2022-03-01" ) - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of Azure VM services proxy. @@ -125,26 +133,31 @@ def __init__(self, New methods to register with the service. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - # SupportsHostProvisioning - self.provision_host, - self.deprovision_host, - self.deallocate_host, - self.wait_host_deployment, - # SupportsHostOps - self.start_host, - self.stop_host, - self.restart_host, - self.wait_host_operation, - # SupportsOSOps - self.shutdown, - self.reboot, - self.wait_os_operation, - # SupportsRemoteExec - self.remote_exec, - self.get_remote_exec_results, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + # SupportsHostProvisioning + self.provision_host, + self.deprovision_host, + self.deallocate_host, + self.wait_host_deployment, + # SupportsHostOps + self.start_host, + self.stop_host, + self.restart_host, + self.wait_host_operation, + # SupportsOSOps + self.shutdown, + self.reboot, + self.wait_os_operation, + # SupportsRemoteExec + self.remote_exec, + self.get_remote_exec_results, + ], + ), ) # As a convenience, allow reading customData out of a file, rather than @@ -153,19 +166,23 @@ def __init__(self, # can be done using the `base64()` string function inside the ARM template. 
self._custom_data_file = self.config.get("customDataFile", None) if self._custom_data_file: - if self._deploy_params.get('customData', None): + if self._deploy_params.get("customData", None): raise ValueError("Both customDataFile and customData are specified.") - self._custom_data_file = self.config_loader_service.resolve_path(self._custom_data_file) - with open(self._custom_data_file, 'r', encoding='utf-8') as custom_data_fh: + self._custom_data_file = self.config_loader_service.resolve_path( + self._custom_data_file + ) + with open(self._custom_data_file, "r", encoding="utf-8") as custom_data_fh: self._deploy_params["customData"] = custom_data_fh.read() - def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use + def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self-use # Try and provide a semi sane default for the deploymentName if not provided # since this is a common way to set the deploymentName and can same some # config work for the caller. if "vmName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vmName']}-deployment" - _LOG.info("deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"]) + _LOG.info( + "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] + ) return params def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dict]: @@ -260,16 +277,19 @@ def deprovision_host(self, params: dict) -> Tuple[Status, dict]: "resourceGroup", "deploymentName", "vmName", - ] + ], ) _LOG.info("Deprovision VM: %s", config["vmName"]) _LOG.info("Deprovision deployment: %s", config["deploymentName"]) # TODO: Properly deprovision *all* resources specified in the ARM template. - return self._azure_rest_api_post_helper(config, self._URL_DEPROVISION.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_DEPROVISION.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def deallocate_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -298,14 +318,17 @@ def deallocate_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Deallocate VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_DEALLOCATE.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_DEALLOCATE.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def start_host(self, params: dict) -> Tuple[Status, dict]: """ @@ -330,14 +353,17 @@ def start_host(self, params: dict) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Start VM: %s :: %s", config["vmName"], params) - return self._azure_rest_api_post_helper(config, self._URL_START.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_START.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: """ @@ -364,14 
+390,17 @@ def stop_host(self, params: dict, force: bool = False) -> Tuple[Status, dict]: "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Stop VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_STOP.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_STOP.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def shutdown(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.stop_host(params, force) @@ -401,20 +430,24 @@ def restart_host(self, params: dict, force: bool = False) -> Tuple[Status, dict] "subscription", "resourceGroup", "vmName", - ] + ], ) _LOG.info("Reboot VM: %s", config["vmName"]) - return self._azure_rest_api_post_helper(config, self._URL_REBOOT.format( - subscription=config["subscription"], - resource_group=config["resourceGroup"], - vm_name=config["vmName"], - )) + return self._azure_rest_api_post_helper( + config, + self._URL_REBOOT.format( + subscription=config["subscription"], + resource_group=config["resourceGroup"], + vm_name=config["vmName"], + ), + ) def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.restart_host(params, force) - def remote_exec(self, script: Iterable[str], config: dict, - env_params: dict) -> Tuple[Status, dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple[Status, dict]: """ Run a command on Azure VM. @@ -444,7 +477,7 @@ def remote_exec(self, script: Iterable[str], config: dict, "subscription", "resourceGroup", "vmName", - ] + ], ) if _LOG.isEnabledFor(logging.INFO): @@ -453,7 +486,7 @@ def remote_exec(self, script: Iterable[str], config: dict, json_req = { "commandId": "RunShellScript", "script": list(script), - "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()] + "parameters": [{"name": key, "value": val} for (key, val) in env_params.items()], } url = self._URL_REXEC_RUN.format( @@ -466,12 +499,15 @@ def remote_exec(self, script: Iterable[str], config: dict, _LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2)) response = requests.post( - url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout) + url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout + ) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Response: %s\n%s", response, - json.dumps(response.json(), indent=2) - if response.content else "") + _LOG.debug( + "Response: %s\n%s", + response, + json.dumps(response.json(), indent=2) if response.content else "", + ) else: _LOG.info("Response: %s", response) @@ -479,10 +515,10 @@ def remote_exec(self, script: Iterable[str], config: dict, # TODO: extract the results from JSON response return (Status.SUCCEEDED, config) elif response.status_code == 202: - return (Status.PENDING, { - **config, - "asyncResultsUrl": response.headers.get("Azure-AsyncOperation") - }) + return ( + Status.PENDING, + {**config, "asyncResultsUrl": response.headers.get("Azure-AsyncOperation")}, + ) else: _LOG.error("Response: %s :: %s", response, response.text) # _LOG.error("Bad Request:\n%s", response.request.body) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index 94947f69b0..db44a5411d 100644 --- 
a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -27,9 +27,14 @@ class CopyMode(Enum): class SshFileShareService(FileShareService, SshService): """A collection of functions for interacting with SSH servers as file shares.""" - async def _start_file_copy(self, params: dict, mode: CopyMode, - local_path: str, remote_path: str, - recursive: bool = True) -> None: + async def _start_file_copy( + self, + params: dict, + mode: CopyMode, + local_path: str, + remote_path: str, + recursive: bool = True, + ) -> None: # pylint: disable=too-many-arguments """ Starts a file copy operation. @@ -37,7 +42,8 @@ async def _start_file_copy(self, params: dict, mode: CopyMode, Parameters ---------- params : dict - Flat dictionary of (key, value) pairs of parameters (used for establishing the connection). + Flat dictionary of (key, value) pairs of parameters (used for + establishing the connection). mode : CopyMode Whether to download or upload the file. local_path : str @@ -69,40 +75,52 @@ async def _start_file_copy(self, params: dict, mode: CopyMode, raise ValueError(f"Unknown copy mode: {mode}") return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ] + ], ) super().download(params, remote_path, local_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive)) + self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive) + ) try: file_copy_future.result() except (OSError, SFTPError) as ex: - _LOG.error("Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex) + _LOG.error( + "Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex + ) if isinstance(ex, SFTPNoSuchFile) or ( - isinstance(ex, SFTPFailure) and ex.code == 4 - and any(msg.lower() in ex.reason.lower() for msg in ("File not found", "No such file or directory")) + isinstance(ex, SFTPFailure) + and ex.code == 4 + and any( + msg.lower() in ex.reason.lower() + for msg in ("File not found", "No such file or directory") + ) ): _LOG.warning("File %s does not exist on %s", remote_path, params) raise FileNotFoundError(f"File {remote_path} does not exist on {params}") from ex raise ex - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: params = merge_parameters( dest=self.config.copy(), source=params, required_keys=[ "ssh_hostname", - ] + ], ) super().upload(params, local_path, remote_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive)) + self._start_file_copy(params, CopyMode.UPLOAD, local_path, remote_path, recursive) + ) try: file_copy_future.result() except (OSError, SFTPError) as ex: diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index 26e886b83d..db7dbdffe0 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ 
b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -25,11 +25,13 @@ class SshHostService(SshService, SupportsOSOps, SupportsRemoteExec): # pylint: disable=too-many-instance-attributes - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of an SSH Service. @@ -48,24 +50,33 @@ def __init__(self, # Same methods are also provided by the AzureVMService class # pylint: disable=duplicate-code super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.shutdown, - self.reboot, - self.wait_os_operation, - self.remote_exec, - self.get_remote_exec_results, - ])) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.shutdown, + self.reboot, + self.wait_os_operation, + self.remote_exec, + self.get_remote_exec_results, + ], + ), + ) self._shell = self.config.get("ssh_shell", "/bin/bash") - async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) -> SSHCompletedProcess: + async def _run_cmd( + self, params: dict, script: Iterable[str], env_params: dict + ) -> SSHCompletedProcess: """ Runs a command asynchronously on a host via SSH. Parameters ---------- params : dict - Flat dictionary of (key, value) pairs of parameters (used for establishing the connection). + Flat dictionary of (key, value) pairs of parameters (used for + establishing the connection). cmd : str Command(s) to run via shell. @@ -78,19 +89,23 @@ async def _run_cmd(self, params: dict, script: Iterable[str], env_params: dict) # Script should be an iterable of lines, not an iterable string. script = [script] connection, _ = await self._get_client_connection(params) - # Note: passing environment variables to SSH servers is typically restricted to just some LC_* values. + # Note: passing environment variables to SSH servers is typically restricted + # to just some LC_* values. # Handle transferring environment variables by making a script to set them. env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()] - script_lines = env_script_lines + [line_split for line in script for line_split in line.splitlines()] + script_lines = env_script_lines + [ + line_split for line in script for line_split in line.splitlines() + ] # Note: connection.run() uses "exec" with a shell by default. - script_str = '\n'.join(script_lines) + script_str = "\n".join(script_lines) _LOG.debug("Running script on %s:\n%s", connection, script_str) - return await connection.run(script_str, - check=False, - timeout=self._request_timeout, - env=env_params) + return await connection.run( + script_str, check=False, timeout=self._request_timeout, env=env_params + ) - def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> Tuple["Status", dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple["Status", dict]: """ Start running a command on remote host OS. 
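The _run_cmd() hunk above also documents a practical workaround: SSH servers typically accept only a narrow allow-list of environment variables (often just LC_* values), so rather than passing env_params through the transport, the service prepends export statements to the script it sends. A standalone sketch of that script-building step, without opening a real SSH connection:

from typing import Iterable, Mapping


def build_remote_script(script: Iterable[str], env_params: Mapping[str, str]) -> str:
    """Prepend `export NAME='value'` lines for env_params to the script lines."""
    env_script_lines = [f"export {name}='{value}'" for (name, value) in env_params.items()]
    script_lines = env_script_lines + [
        line_split for line in script for line_split in line.splitlines()
    ]
    return "\n".join(script_lines)


print(build_remote_script(["echo $GREETING $TARGET"], {"GREETING": "hello", "TARGET": "world"}))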
@@ -117,9 +132,11 @@ def remote_exec(self, script: Iterable[str], config: dict, env_params: dict) -> source=config, required_keys=[ "ssh_hostname", - ] + ], + ) + config["asyncRemoteExecResultsFuture"] = self._run_coroutine( + self._run_cmd(config, script, env_params) ) - config["asyncRemoteExecResultsFuture"] = self._run_coroutine(self._run_cmd(config, script, env_params)) return (Status.PENDING, config) def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: @@ -150,7 +167,11 @@ def get_remote_exec_results(self, config: dict) -> Tuple["Status", dict]: stdout = result.stdout.decode() if isinstance(result.stdout, bytes) else result.stdout stderr = result.stderr.decode() if isinstance(result.stderr, bytes) else result.stderr return ( - Status.SUCCEEDED if result.exit_status == 0 and result.returncode == 0 else Status.FAILED, + ( + Status.SUCCEEDED + if result.exit_status == 0 and result.returncode == 0 + else Status.FAILED + ), { "stdout": stdout, "stderr": stderr, @@ -183,9 +204,9 @@ def _exec_os_op(self, cmd_opts_list: List[str], params: dict) -> Tuple[Status, d source=params, required_keys=[ "ssh_hostname", - ] + ], ) - cmd_opts = ' '.join([f"'{cmd}'" for cmd in cmd_opts_list]) + cmd_opts = " ".join([f"'{cmd}'" for cmd in cmd_opts_list]) script = rf""" if [[ $EUID -ne 0 ]]; then sudo=$(command -v sudo) @@ -220,10 +241,10 @@ def shutdown(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - 'shutdown -h now', - 'poweroff', - 'halt -p', - 'systemctl poweroff', + "shutdown -h now", + "poweroff", + "halt -p", + "systemctl poweroff", ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) @@ -245,11 +266,11 @@ def reboot(self, params: dict, force: bool = False) -> Tuple[Status, dict]: Status is one of {PENDING, SUCCEEDED, FAILED} """ cmd_opts_list = [ - 'shutdown -r now', - 'reboot', - 'halt --reboot', - 'systemctl reboot', - 'kill -KILL 1; kill -KILL -1' if force else 'kill -TERM 1; kill -TERM -1', + "shutdown -r now", + "reboot", + "halt --reboot", + "systemctl reboot", + "kill -KILL 1; kill -KILL -1" if force else "kill -TERM 1; kill -TERM -1", ] return self._exec_os_op(cmd_opts_list=cmd_opts_list, params=params) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 272f908c78..8c0b2b8b7a 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -48,8 +48,8 @@ class SshClient(asyncssh.SSHClient): command. 
""" - _CONNECTION_PENDING = 'INIT' - _CONNECTION_LOST = 'LOST' + _CONNECTION_PENDING = "INIT" + _CONNECTION_LOST = "LOST" def __init__(self, *args: tuple, **kwargs: dict): self._connection_id: str = SshClient._CONNECTION_PENDING @@ -63,12 +63,16 @@ def __repr__(self) -> str: @staticmethod def id_from_connection(connection: SSHClientConnection) -> str: """Gets a unique id repr for the connection.""" - return f"{connection._username}@{connection._host}:{connection._port}" # pylint: disable=protected-access + # pylint: disable=protected-access + return f"{connection._username}@{connection._host}:{connection._port}" @staticmethod def id_from_params(connect_params: dict) -> str: """Gets a unique id repr for the connection.""" - return f"{connect_params.get('username')}@{connect_params['host']}:{connect_params.get('port')}" + return ( + f"{connect_params.get('username')}@{connect_params['host']}" + f":{connect_params.get('port')}" + ) def connection_made(self, conn: SSHClientConnection) -> None: """ @@ -77,8 +81,12 @@ def connection_made(self, conn: SSHClientConnection) -> None: Changes the connection_id from _CONNECTION_PENDING to a unique id repr. """ self._conn_event.clear() - _LOG.debug("%s: Connection made by %s: %s", current_thread().name, conn._options.env, conn) \ - # pylint: disable=protected-access + _LOG.debug( + "%s: Connection made by %s: %s", + current_thread().name, + conn._options.env, # pylint: disable=protected-access + conn, + ) self._connection_id = SshClient.id_from_connection(conn) self._connection = conn self._conn_event.set() @@ -88,9 +96,19 @@ def connection_lost(self, exc: Optional[Exception]) -> None: self._conn_event.clear() _LOG.debug("%s: %s", current_thread().name, "connection_lost") if exc is None: - _LOG.debug("%s: gracefully disconnected ssh from %s: %s", current_thread().name, self._connection_id, exc) + _LOG.debug( + "%s: gracefully disconnected ssh from %s: %s", + current_thread().name, + self._connection_id, + exc, + ) else: - _LOG.debug("%s: ssh connection lost on %s: %s", current_thread().name, self._connection_id, exc) + _LOG.debug( + "%s: ssh connection lost on %s: %s", + current_thread().name, + self._connection_id, + exc, + ) self._connection_id = SshClient._CONNECTION_LOST self._connection = None self._conn_event.set() @@ -144,7 +162,9 @@ def exit(self) -> None: warn(RuntimeWarning("SshClientCache lock was still held on exit.")) self._cache_lock.release() - async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientConnection, SshClient]: + async def get_client_connection( + self, connect_params: dict + ) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) client connection. @@ -167,13 +187,21 @@ async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientCo _LOG.debug("%s: Checking cached client %s", current_thread().name, connection_id) connection = await client.connection() if not connection: - _LOG.debug("%s: Removing stale client connection %s from cache.", current_thread().name, connection_id) + _LOG.debug( + "%s: Removing stale client connection %s from cache.", + current_thread().name, + connection_id, + ) self._cache.pop(connection_id) # Try to reconnect next. 
else: _LOG.debug("%s: Using cached client %s", current_thread().name, connection_id) if connection_id not in self._cache: - _LOG.debug("%s: Establishing client connection to %s", current_thread().name, connection_id) + _LOG.debug( + "%s: Establishing client connection to %s", + current_thread().name, + connection_id, + ) connection, client = await asyncssh.create_connection(SshClient, **connect_params) assert isinstance(client, SshClient) self._cache[connection_id] = (connection, client) @@ -182,7 +210,7 @@ async def get_client_connection(self, connect_params: dict) -> Tuple[SSHClientCo def cleanup(self) -> None: """Closes all cached connections.""" - for (connection, _) in self._cache.values(): + for connection, _ in self._cache.values(): connection.close() self._cache = {} @@ -220,21 +248,23 @@ class SshService(Service, metaclass=ABCMeta): _REQUEST_TIMEOUT: Optional[float] = None # seconds - def __init__(self, - config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__(config, global_config, parent, methods) # Make sure that the value we allow overriding on a per-connection # basis are present in the config so merge_parameters can do its thing. - self.config.setdefault('ssh_port', None) - assert isinstance(self.config['ssh_port'], (int, type(None))) - self.config.setdefault('ssh_username', None) - assert isinstance(self.config['ssh_username'], (str, type(None))) - self.config.setdefault('ssh_priv_key_path', None) - assert isinstance(self.config['ssh_priv_key_path'], (str, type(None))) + self.config.setdefault("ssh_port", None) + assert isinstance(self.config["ssh_port"], (int, type(None))) + self.config.setdefault("ssh_username", None) + assert isinstance(self.config["ssh_username"], (str, type(None))) + self.config.setdefault("ssh_priv_key_path", None) + assert isinstance(self.config["ssh_priv_key_path"], (str, type(None))) # None can be used to disable the request timeout. self._request_timeout = self.config.get("ssh_request_timeout", self._REQUEST_TIMEOUT) @@ -245,24 +275,25 @@ def __init__(self, # In general scripted commands shouldn't need a pty and having one # available can confuse some commands, though we may need to make # this configurable in the future. - 'request_pty': False, - # By default disable known_hosts checking (since most VMs expected to be dynamically created). - 'known_hosts': None, + "request_pty": False, + # By default disable known_hosts checking (since most VMs expected to be + # dynamically created). 
+ "known_hosts": None, } - if 'ssh_known_hosts_file' in self.config: - self._connect_params['known_hosts'] = self.config.get("ssh_known_hosts_file", None) - if isinstance(self._connect_params['known_hosts'], str): - known_hosts_file = os.path.expanduser(self._connect_params['known_hosts']) + if "ssh_known_hosts_file" in self.config: + self._connect_params["known_hosts"] = self.config.get("ssh_known_hosts_file", None) + if isinstance(self._connect_params["known_hosts"], str): + known_hosts_file = os.path.expanduser(self._connect_params["known_hosts"]) if not os.path.exists(known_hosts_file): raise ValueError(f"ssh_known_hosts_file {known_hosts_file} does not exist") - self._connect_params['known_hosts'] = known_hosts_file - if self._connect_params['known_hosts'] is None: + self._connect_params["known_hosts"] = known_hosts_file + if self._connect_params["known_hosts"] is None: _LOG.info("%s known_hosts checking is disabled per config.", self) - if 'ssh_keepalive_interval' in self.config: - keepalive_internal = self.config.get('ssh_keepalive_interval') - self._connect_params['keepalive_interval'] = nullable(int, keepalive_internal) + if "ssh_keepalive_interval" in self.config: + keepalive_internal = self.config.get("ssh_keepalive_interval") + self._connect_params["keepalive_interval"] = nullable(int, keepalive_internal) def _enter_context(self) -> "SshService": # Start the background thread if it's not already running. @@ -272,9 +303,12 @@ def _enter_context(self) -> "SshService": super()._enter_context() return self - def _exit_context(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def _exit_context( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: # Stop the background thread if it's not needed anymore and potentially # cleanup the cache as well. assert self._in_context @@ -330,24 +364,26 @@ def _get_connect_params(self, params: dict) -> dict: # Start with the base config params. 
connect_params = self._connect_params.copy() - connect_params['host'] = params['ssh_hostname'] # required + connect_params["host"] = params["ssh_hostname"] # required - if params.get('ssh_port'): - connect_params['port'] = int(params.pop('ssh_port')) - elif self.config['ssh_port']: - connect_params['port'] = int(self.config['ssh_port']) + if params.get("ssh_port"): + connect_params["port"] = int(params.pop("ssh_port")) + elif self.config["ssh_port"]: + connect_params["port"] = int(self.config["ssh_port"]) - if 'ssh_username' in params: - connect_params['username'] = str(params.pop('ssh_username')) - elif self.config['ssh_username']: - connect_params['username'] = str(self.config['ssh_username']) + if "ssh_username" in params: + connect_params["username"] = str(params.pop("ssh_username")) + elif self.config["ssh_username"]: + connect_params["username"] = str(self.config["ssh_username"]) - priv_key_file: Optional[str] = params.get('ssh_priv_key_path', self.config['ssh_priv_key_path']) + priv_key_file: Optional[str] = params.get( + "ssh_priv_key_path", self.config["ssh_priv_key_path"] + ) if priv_key_file: priv_key_file = os.path.expanduser(priv_key_file) if not os.path.exists(priv_key_file): raise ValueError(f"ssh_priv_key_path {priv_key_file} does not exist") - connect_params['client_keys'] = [priv_key_file] + connect_params["client_keys"] = [priv_key_file] return connect_params @@ -366,4 +402,6 @@ async def _get_client_connection(self, params: dict) -> Tuple[SSHClientConnectio The connection and client objects. """ assert self._in_context - return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection(self._get_connect_params(params)) + return await SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE.get_client_connection( + self._get_connect_params(params) + ) diff --git a/mlos_bench/mlos_bench/services/types/__init__.py b/mlos_bench/mlos_bench/services/types/__init__.py index e691d64514..e2d0cb55b5 100644 --- a/mlos_bench/mlos_bench/services/types/__init__.py +++ b/mlos_bench/mlos_bench/services/types/__init__.py @@ -18,12 +18,12 @@ from mlos_bench.services.types.remote_exec_type import SupportsRemoteExec __all__ = [ - 'SupportsAuth', - 'SupportsConfigLoading', - 'SupportsFileShareOps', - 'SupportsHostProvisioning', - 'SupportsLocalExec', - 'SupportsNetworkProvisioning', - 'SupportsRemoteConfig', - 'SupportsRemoteExec', + "SupportsAuth", + "SupportsConfigLoading", + "SupportsFileShareOps", + "SupportsHostProvisioning", + "SupportsLocalExec", + "SupportsNetworkProvisioning", + "SupportsRemoteConfig", + "SupportsRemoteExec", ] diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index c0b2d7335b..e29e5688ec 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -30,8 +30,7 @@ class SupportsConfigLoading(Protocol): """Protocol interface for helper functions to lookup and load configs.""" - def resolve_path(self, file_path: str, - extra_paths: Optional[Iterable[str]] = None) -> str: + def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = None) -> str: """ Prepend the suitable `_config_path` to `path` if the latter is not absolute. If `_config_path` is `None` or `path` is absolute, return `path` as is. @@ -49,7 +48,9 @@ def resolve_path(self, file_path: str, An actual path to the config or script. 
""" - def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) -> Union[dict, List[dict]]: + def load_config( + self, json_file_name: str, schema_type: Optional[ConfigSchema] + ) -> Union[dict, List[dict]]: """ Load JSON config file. Search for a file relative to `_config_path` if the input path is not absolute. This method is exported to be used as a service. @@ -67,12 +68,14 @@ def load_config(self, json_file_name: str, schema_type: Optional[ConfigSchema]) Free-format dictionary that contains the configuration. """ - def build_environment(self, # pylint: disable=too-many-arguments - config: dict, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None) -> "Environment": + def build_environment( + self, # pylint: disable=too-many-arguments + config: dict, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None, + ) -> "Environment": """ Factory method for a new environment with a given config. @@ -102,12 +105,13 @@ def build_environment(self, # pylint: disable=too-many-arguments """ def load_environment_list( # pylint: disable=too-many-arguments - self, - json_file_name: str, - tunables: "TunableGroups", - global_config: Optional[dict] = None, - parent_args: Optional[Dict[str, TunableValue]] = None, - service: Optional["Service"] = None) -> List["Environment"]: + self, + json_file_name: str, + tunables: "TunableGroups", + global_config: Optional[dict] = None, + parent_args: Optional[Dict[str, TunableValue]] = None, + service: Optional["Service"] = None, + ) -> List["Environment"]: """ Load and build a list of environments from the config file. @@ -132,9 +136,12 @@ def load_environment_list( # pylint: disable=too-many-arguments A list of new benchmarking environments. """ - def load_services(self, json_file_names: Iterable[str], - global_config: Optional[Dict[str, Any]] = None, - parent: Optional["Service"] = None) -> "Service": + def load_services( + self, + json_file_names: Iterable[str], + global_config: Optional[Dict[str, Any]] = None, + parent: Optional["Service"] = None, + ) -> "Service": """ Read the configuration files and bundle all service methods from those configs into a single Service object. diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py index 607f5cb674..c2ff153ac7 100644 --- a/mlos_bench/mlos_bench/services/types/fileshare_type.py +++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py @@ -11,7 +11,9 @@ class SupportsFileShareOps(Protocol): """Protocol interface for file share operations.""" - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: """ Downloads contents from a remote share path to a local path. @@ -29,7 +31,9 @@ def download(self, params: dict, remote_path: str, local_path: str, recursive: b if True (the default), download the entire directory tree. """ - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: """ Uploads contents from a local path to remote share path. 
diff --git a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py index 1be95aab22..1df0716fa1 100644 --- a/mlos_bench/mlos_bench/services/types/host_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/host_provisioner_type.py @@ -42,7 +42,8 @@ def wait_host_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Status params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for Host/VM being deployed; otherwise, wait for successful deprovisioning. + If True, wait for Host/VM being deployed; otherwise, wait for successful + deprovisioning. Returns ------- diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index 1c8f5f627e..9c4d2dc224 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -31,9 +31,12 @@ class SupportsLocalExec(Protocol): vs the target environment. Used in LocalEnv and provided by LocalExecService. """ - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, TunableValue]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, TunableValue]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: """ Execute the script lines from `script_lines` in a local process. @@ -54,7 +57,9 @@ def local_exec(self, script_lines: Iterable[str], A 3-tuple of return code, stdout, and stderr of the script process. """ - def temp_dir_context(self, path: Optional[str] = None) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: + def temp_dir_context( + self, path: Optional[str] = None + ) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 5c1812f1f0..19e7b16350 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -42,7 +42,8 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta params : dict Flat dictionary of (key, value) pairs of tunable parameters. is_setup : bool - If True, wait for Network being deployed; otherwise, wait for successful deprovisioning. + If True, wait for Network being deployed; otherwise, wait for successful + deprovisioning. Returns ------- @@ -52,7 +53,9 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta Result is info on the operation runtime if SUCCEEDED, otherwise {}. """ - def deprovision_network(self, params: dict, ignore_errors: bool = True) -> Tuple["Status", dict]: + def deprovision_network( + self, params: dict, ignore_errors: bool = True + ) -> Tuple["Status", dict]: """ Deprovisions the Network by deleting it. 
diff --git a/mlos_bench/mlos_bench/services/types/remote_config_type.py b/mlos_bench/mlos_bench/services/types/remote_config_type.py index c25bc7b0ba..7e8d0a6e77 100644 --- a/mlos_bench/mlos_bench/services/types/remote_config_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_config_type.py @@ -14,8 +14,7 @@ class SupportsRemoteConfig(Protocol): """Protocol interface for configuring cloud services.""" - def configure(self, config: Dict[str, Any], - params: Dict[str, Any]) -> Tuple["Status", dict]: + def configure(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple["Status", dict]: """ Update the parameters of a SaaS service in the cloud. diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index cba9e31b22..dd105f7a41 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -18,8 +18,9 @@ class SupportsRemoteExec(Protocol): on a remote host OS. """ - def remote_exec(self, script: Iterable[str], config: dict, - env_params: dict) -> Tuple["Status", dict]: + def remote_exec( + self, script: Iterable[str], config: dict, env_params: dict + ) -> Tuple["Status", dict]: """ Run a command on remote host OS. diff --git a/mlos_bench/mlos_bench/storage/__init__.py b/mlos_bench/mlos_bench/storage/__init__.py index a5bfeb7145..64e70c20f7 100644 --- a/mlos_bench/mlos_bench/storage/__init__.py +++ b/mlos_bench/mlos_bench/storage/__init__.py @@ -8,6 +8,6 @@ from mlos_bench.storage.storage_factory import from_config __all__ = [ - 'Storage', - 'from_config', + "Storage", + "from_config", ] diff --git a/mlos_bench/mlos_bench/storage/base_experiment_data.py b/mlos_bench/mlos_bench/storage/base_experiment_data.py index a6cb7d496a..60c27bc522 100644 --- a/mlos_bench/mlos_bench/storage/base_experiment_data.py +++ b/mlos_bench/mlos_bench/storage/base_experiment_data.py @@ -30,12 +30,15 @@ class ExperimentData(metaclass=ABCMeta): RESULT_COLUMN_PREFIX = "result." CONFIG_COLUMN_PREFIX = "config." - def __init__(self, *, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str): + def __init__( + self, + *, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str, + ): self._experiment_id = experiment_id self._description = description self._root_env_config = root_env_config @@ -137,9 +140,9 @@ def default_tunable_config_id(self) -> Optional[int]: trials_items = sorted(self.trials.items()) if not trials_items: return None - for (_trial_id, trial) in trials_items: + for _trial_id, trial in trials_items: # Take the first config id marked as "defaults" when it was instantiated. - if strtobool(str(trial.metadata_dict.get('is_defaults', False))): + if strtobool(str(trial.metadata_dict.get("is_defaults", False))): return trial.tunable_config_id # Fallback (min trial_id) return trials_items[0][1].tunable_config_id @@ -154,7 +157,8 @@ def results_df(self) -> pandas.DataFrame: ------- results : pandas.DataFrame A DataFrame with configurations and results from all trials of the experiment. - Has columns [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status] + Has columns + [trial_id, tunable_config_id, tunable_config_trial_group_id, ts_start, ts_end, status] followed by tunable config parameters (prefixed with "config.") and trial results (prefixed with "result."). The latter can be NULLs if the trial was not successful. 
diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 39b3bf851b..41c9df0e5e 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -27,10 +27,12 @@ class Storage(metaclass=ABCMeta): (e.g., SQLite or MLFLow). """ - def __init__(self, - config: Dict[str, Any], - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, + config: Dict[str, Any], + global_config: Optional[dict] = None, + service: Optional[Service] = None, + ): """ Create a new storage object. @@ -70,13 +72,16 @@ def experiments(self) -> Dict[str, ExperimentData]: """ @abstractmethod - def experiment(self, *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal['min', 'max']]) -> 'Storage.Experiment': + def experiment( + self, + *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal["min", "max"]], + ) -> "Storage.Experiment": """ Create a new experiment in the storage. @@ -113,23 +118,27 @@ class Experiment(metaclass=ABCMeta): This class is instantiated in the `Storage.experiment()` method. """ - def __init__(self, - *, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal['min', 'max']]): + def __init__( + self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal["min", "max"]], + ): self._tunables = tunables.copy() self._trial_id = trial_id self._experiment_id = experiment_id - (self._git_repo, self._git_commit, self._root_env_config) = get_git_info(root_env_config) + (self._git_repo, self._git_commit, self._root_env_config) = get_git_info( + root_env_config + ) self._description = description self._opt_targets = opt_targets self._in_context = False - def __enter__(self) -> 'Storage.Experiment': + def __enter__(self) -> "Storage.Experiment": """ Enter the context of the experiment. @@ -141,9 +150,12 @@ def __enter__(self) -> 'Storage.Experiment': self._in_context = True return self - def __exit__(self, exc_type: Optional[Type[BaseException]], - exc_val: Optional[BaseException], - exc_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Literal[False]: """ End the context of the experiment. @@ -154,8 +166,9 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], _LOG.debug("Finishing experiment: %s", self) else: assert exc_type and exc_val - _LOG.warning("Finishing experiment: %s", self, - exc_info=(exc_type, exc_val, exc_tb)) + _LOG.warning( + "Finishing experiment: %s", self, exc_info=(exc_type, exc_val, exc_tb) + ) assert self._in_context self._teardown(is_ok) self._in_context = False @@ -242,8 +255,10 @@ def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: """ @abstractmethod - def load(self, last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load( + self, + last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: """ Load (tunable values, benchmark scores, status) to warm-up the optimizer. 
@@ -263,7 +278,9 @@ def load(self, last_trial_id: int = -1, """ @abstractmethod - def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Storage.Trial']: + def pending_trials( + self, timestamp: datetime, *, running: bool + ) -> Iterator["Storage.Trial"]: """ Return an iterator over the pending trials that are scheduled to run on or before the specified timestamp. @@ -283,8 +300,12 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator['Sto """ @abstractmethod - def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> 'Storage.Trial': + def new_trial( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> "Storage.Trial": """ Create a new experiment run in the storage. @@ -312,10 +333,16 @@ class Trial(metaclass=ABCMeta): This class is instantiated in the `Storage.Experiment.trial()` method. """ - def __init__(self, *, - tunables: TunableGroups, experiment_id: str, trial_id: int, - tunable_config_id: int, opt_targets: Dict[str, Literal['min', 'max']], - config: Optional[Dict[str, Any]] = None): + def __init__( + self, + *, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + opt_targets: Dict[str, Literal["min", "max"]], + config: Optional[Dict[str, Any]] = None, + ): self._tunables = tunables self._experiment_id = experiment_id self._trial_id = trial_id @@ -367,9 +394,9 @@ def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, An return config @abstractmethod - def update(self, status: Status, timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update( + self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: """ Update the storage with the results of the experiment. @@ -393,14 +420,18 @@ def update(self, status: Status, timestamp: datetime, assert metrics is not None opt_targets = set(self._opt_targets.keys()) if not opt_targets.issubset(metrics.keys()): - _LOG.warning("Trial %s :: opt.targets missing: %s", - self, opt_targets.difference(metrics.keys())) + _LOG.warning( + "Trial %s :: opt.targets missing: %s", + self, + opt_targets.difference(metrics.keys()), + ) # raise ValueError() return metrics @abstractmethod - def update_telemetry(self, status: Status, timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]]) -> None: + def update_telemetry( + self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] + ) -> None: """ Save the experiment's telemetry data and intermediate status. diff --git a/mlos_bench/mlos_bench/storage/base_trial_data.py b/mlos_bench/mlos_bench/storage/base_trial_data.py index f9f7b93322..2a74d77e5e 100644 --- a/mlos_bench/mlos_bench/storage/base_trial_data.py +++ b/mlos_bench/mlos_bench/storage/base_trial_data.py @@ -29,13 +29,16 @@ class TrialData(metaclass=ABCMeta): tunable parameters). 
""" - def __init__(self, *, - experiment_id: str, - trial_id: int, - tunable_config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status): + def __init__( + self, + *, + experiment_id: str, + trial_id: int, + tunable_config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status, + ): self._experiment_id = experiment_id self._trial_id = trial_id self._tunable_config_id = tunable_config_id @@ -46,7 +49,10 @@ def __init__(self, *, self._status = status def __repr__(self) -> str: - return f"Trial :: {self._experiment_id}:{self._trial_id} cid:{self._tunable_config_id} {self._status.name}" + return ( + f"Trial :: {self._experiment_id}:{self._trial_id} " + f"cid:{self._tunable_config_id} {self._status.name}" + ) def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py index 0d58c20dc8..62751deb8e 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_data.py @@ -19,8 +19,7 @@ class TunableConfigData(metaclass=ABCMeta): A configuration in this context is the set of tunable parameter values. """ - def __init__(self, *, - tunable_config_id: int): + def __init__(self, *, tunable_config_id: int): self._tunable_config_id = tunable_config_id def __repr__(self) -> str: diff --git a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py index 62c01c3266..c01c7544b3 100644 --- a/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/base_tunable_config_trial_group_data.py @@ -25,10 +25,13 @@ class TunableConfigTrialGroupData(metaclass=ABCMeta): (e.g., for repeats), which we call a (tunable) config trial group. 
""" - def __init__(self, *, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None): + def __init__( + self, + *, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None, + ): self._experiment_id = experiment_id self._tunable_config_id = tunable_config_id # can be lazily initialized as necessary: @@ -69,7 +72,10 @@ def __repr__(self) -> str: def __eq__(self, other: Any) -> bool: if not isinstance(other, self.__class__): return False - return self._tunable_config_id == other._tunable_config_id and self._experiment_id == other._experiment_id + return ( + self._tunable_config_id == other._tunable_config_id + and self._experiment_id == other._experiment_id + ) @property @abstractmethod diff --git a/mlos_bench/mlos_bench/storage/sql/__init__.py b/mlos_bench/mlos_bench/storage/sql/__init__.py index 86fd6de291..9d749ed35d 100644 --- a/mlos_bench/mlos_bench/storage/sql/__init__.py +++ b/mlos_bench/mlos_bench/storage/sql/__init__.py @@ -6,5 +6,5 @@ from mlos_bench.storage.sql.storage import SqlStorage __all__ = [ - 'SqlStorage', + "SqlStorage", ] diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index fed66b339d..5fdc6c0731 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -16,26 +16,28 @@ def get_trials( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None) -> Dict[int, TrialData]: + engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None +) -> Dict[int, TrialData]: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. Used by both TunableConfigTrialGroupSqlData and ExperimentSqlData. """ - from mlos_bench.storage.sql.trial_data import ( - TrialSqlData, # pylint: disable=import-outside-toplevel,cyclic-import - ) + # pylint: disable=import-outside-toplevel,cyclic-import + from mlos_bench.storage.sql.trial_data import TrialSqlData + with engine.connect() as conn: # Build up sql a statement for fetching trials. - stmt = schema.trial.select().where( - schema.trial.c.exp_id == experiment_id, - ).order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), + stmt = ( + schema.trial.select() + .where( + schema.trial.c.exp_id == experiment_id, + ) + .order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -59,10 +61,8 @@ def get_trials( def get_results_df( - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: Optional[int] = None) -> pandas.DataFrame: + engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None +) -> pandas.DataFrame: """ Gets TrialData for the given experiment_data and optionally additionally restricted by tunable_config_id. @@ -72,15 +72,22 @@ def get_results_df( # pylint: disable=too-many-locals with engine.connect() as conn: # Compose a subquery to fetch the tunable_config_trial_group_id for each tunable config. 
- tunable_config_group_id_stmt = schema.trial.select().with_only_columns( - schema.trial.c.exp_id, - schema.trial.c.config_id, - func.min(schema.trial.c.trial_id).cast(Integer).label('tunable_config_trial_group_id'), - ).where( - schema.trial.c.exp_id == experiment_id, - ).group_by( - schema.trial.c.exp_id, - schema.trial.c.config_id, + tunable_config_group_id_stmt = ( + schema.trial.select() + .with_only_columns( + schema.trial.c.exp_id, + schema.trial.c.config_id, + func.min(schema.trial.c.trial_id) + .cast(Integer) + .label("tunable_config_trial_group_id"), + ) + .where( + schema.trial.c.exp_id == experiment_id, + ) + .group_by( + schema.trial.c.exp_id, + schema.trial.c.config_id, + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -90,18 +97,22 @@ def get_results_df( tunable_config_trial_group_id_subquery = tunable_config_group_id_stmt.subquery() # Get each trial's metadata. - cur_trials_stmt = select( - schema.trial, - tunable_config_trial_group_id_subquery, - ).where( - schema.trial.c.exp_id == experiment_id, - and_( - tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, - tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, - ), - ).order_by( - schema.trial.c.exp_id.asc(), - schema.trial.c.trial_id.asc(), + cur_trials_stmt = ( + select( + schema.trial, + tunable_config_trial_group_id_subquery, + ) + .where( + schema.trial.c.exp_id == experiment_id, + and_( + tunable_config_trial_group_id_subquery.c.exp_id == schema.trial.c.exp_id, + tunable_config_trial_group_id_subquery.c.config_id == schema.trial.c.config_id, + ), + ) + .order_by( + schema.trial.c.exp_id.asc(), + schema.trial.c.trial_id.asc(), + ) ) # Optionally restrict to those using a particular tunable config. if tunable_config_id is not None: @@ -110,39 +121,48 @@ def get_results_df( ) cur_trials = conn.execute(cur_trials_stmt) trials_df = pandas.DataFrame( - [( - row.trial_id, - utcify_timestamp(row.ts_start, origin="utc"), - utcify_nullable_timestamp(row.ts_end, origin="utc"), - row.config_id, - row.tunable_config_trial_group_id, - row.status, - ) for row in cur_trials.fetchall()], + [ + ( + row.trial_id, + utcify_timestamp(row.ts_start, origin="utc"), + utcify_nullable_timestamp(row.ts_end, origin="utc"), + row.config_id, + row.tunable_config_trial_group_id, + row.status, + ) + for row in cur_trials.fetchall() + ], columns=[ - 'trial_id', - 'ts_start', - 'ts_end', - 'tunable_config_id', - 'tunable_config_trial_group_id', - 'status', - ] + "trial_id", + "ts_start", + "ts_end", + "tunable_config_id", + "tunable_config_trial_group_id", + "status", + ], ) # Get each trial's config in wide format. 
- configs_stmt = schema.trial.select().with_only_columns( - schema.trial.c.trial_id, - schema.trial.c.config_id, - schema.config_param.c.param_id, - schema.config_param.c.param_value, - ).where( - schema.trial.c.exp_id == experiment_id, - ).join( - schema.config_param, - schema.config_param.c.config_id == schema.trial.c.config_id, - isouter=True - ).order_by( - schema.trial.c.trial_id, - schema.config_param.c.param_id, + configs_stmt = ( + schema.trial.select() + .with_only_columns( + schema.trial.c.trial_id, + schema.trial.c.config_id, + schema.config_param.c.param_id, + schema.config_param.c.param_value, + ) + .where( + schema.trial.c.exp_id == experiment_id, + ) + .join( + schema.config_param, + schema.config_param.c.config_id == schema.trial.c.config_id, + isouter=True, + ) + .order_by( + schema.trial.c.trial_id, + schema.config_param.c.param_id, + ) ) if tunable_config_id is not None: configs_stmt = configs_stmt.where( @@ -150,41 +170,73 @@ def get_results_df( ) configs = conn.execute(configs_stmt) configs_df = pandas.DataFrame( - [(row.trial_id, row.config_id, ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, row.param_value) - for row in configs.fetchall()], - columns=['trial_id', 'tunable_config_id', 'param', 'value'] + [ + ( + row.trial_id, + row.config_id, + ExperimentData.CONFIG_COLUMN_PREFIX + row.param_id, + row.param_value, + ) + for row in configs.fetchall() + ], + columns=["trial_id", "tunable_config_id", "param", "value"], ).pivot( - index=["trial_id", "tunable_config_id"], columns="param", values="value", + index=["trial_id", "tunable_config_id"], + columns="param", + values="value", ) - configs_df = configs_df.apply(pandas.to_numeric, errors='coerce').fillna(configs_df) # type: ignore[assignment] # (fp) + configs_df = configs_df.apply( # type: ignore[assignment] # (fp) + pandas.to_numeric, + errors="coerce", + ).fillna(configs_df) # Get each trial's results in wide format. 
- results_stmt = schema.trial_result.select().with_only_columns( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, - schema.trial_result.c.metric_value, - ).where( - schema.trial_result.c.exp_id == experiment_id, - ).order_by( - schema.trial_result.c.trial_id, - schema.trial_result.c.metric_id, + results_stmt = ( + schema.trial_result.select() + .with_only_columns( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + schema.trial_result.c.metric_value, + ) + .where( + schema.trial_result.c.exp_id == experiment_id, + ) + .order_by( + schema.trial_result.c.trial_id, + schema.trial_result.c.metric_id, + ) ) if tunable_config_id is not None: - results_stmt = results_stmt.join(schema.trial, and_( - schema.trial.c.exp_id == schema.trial_result.c.exp_id, - schema.trial.c.trial_id == schema.trial_result.c.trial_id, - schema.trial.c.config_id == tunable_config_id, - )) + results_stmt = results_stmt.join( + schema.trial, + and_( + schema.trial.c.exp_id == schema.trial_result.c.exp_id, + schema.trial.c.trial_id == schema.trial_result.c.trial_id, + schema.trial.c.config_id == tunable_config_id, + ), + ) results = conn.execute(results_stmt) results_df = pandas.DataFrame( - [(row.trial_id, ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, row.metric_value) - for row in results.fetchall()], - columns=['trial_id', 'metric', 'value'] + [ + ( + row.trial_id, + ExperimentData.RESULT_COLUMN_PREFIX + row.metric_id, + row.metric_value, + ) + for row in results.fetchall() + ], + columns=["trial_id", "metric", "value"], ).pivot( - index="trial_id", columns="metric", values="value", + index="trial_id", + columns="metric", + values="value", ) - results_df = results_df.apply(pandas.to_numeric, errors='coerce').fillna(results_df) # type: ignore[assignment] # (fp) + results_df = results_df.apply( # type: ignore[assignment] # (fp) + pandas.to_numeric, + errors="coerce", + ).fillna(results_df) # Concat the trials, configs, and results. - return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left") \ - .merge(results_df, on="trial_id", how="left") + return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge( + results_df, on="trial_id", how="left" + ) diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index c96cd503be..443c4b2c82 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -25,15 +25,18 @@ class Experiment(Storage.Experiment): """Logic for retrieving and storing the results of a single experiment.""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - opt_targets: Dict[str, Literal['min', 'max']]): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + opt_targets: Dict[str, Literal["min", "max"]], + ): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -51,18 +54,22 @@ def _setup(self) -> None: # Get git info and the last trial ID for the experiment. 
# pylint: disable=not-callable exp_info = conn.execute( - self._schema.experiment.select().with_only_columns( + self._schema.experiment.select() + .with_only_columns( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, func.max(self._schema.trial.c.trial_id).label("trial_id"), - ).join( + ) + .join( self._schema.trial, self._schema.trial.c.exp_id == self._schema.experiment.c.exp_id, - isouter=True - ).where( + isouter=True, + ) + .where( self._schema.experiment.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.experiment.c.git_repo, self._schema.experiment.c.git_commit, self._schema.experiment.c.root_env_config, @@ -71,33 +78,47 @@ def _setup(self) -> None: if exp_info is None: _LOG.info("Start new experiment: %s", self._experiment_id) # It's a new experiment: create a record for it in the database. - conn.execute(self._schema.experiment.insert().values( - exp_id=self._experiment_id, - description=self._description, - git_repo=self._git_repo, - git_commit=self._git_commit, - root_env_config=self._root_env_config, - )) - conn.execute(self._schema.objectives.insert().values([ - { - "exp_id": self._experiment_id, - "optimization_target": opt_target, - "optimization_direction": opt_dir, - } - for (opt_target, opt_dir) in self.opt_targets.items() - ])) + conn.execute( + self._schema.experiment.insert().values( + exp_id=self._experiment_id, + description=self._description, + git_repo=self._git_repo, + git_commit=self._git_commit, + root_env_config=self._root_env_config, + ) + ) + conn.execute( + self._schema.objectives.insert().values( + [ + { + "exp_id": self._experiment_id, + "optimization_target": opt_target, + "optimization_direction": opt_dir, + } + for (opt_target, opt_dir) in self.opt_targets.items() + ] + ) + ) else: if exp_info.trial_id is not None: self._trial_id = exp_info.trial_id + 1 - _LOG.info("Continue experiment: %s last trial: %s resume from: %d", - self._experiment_id, exp_info.trial_id, self._trial_id) + _LOG.info( + "Continue experiment: %s last trial: %s resume from: %d", + self._experiment_id, + exp_info.trial_id, + self._trial_id, + ) # TODO: Sanity check that certain critical configs (e.g., # objectives) haven't changed to be incompatible such that a new # experiment should be started (possibly by prewarming with the # previous one). if exp_info.git_commit != self._git_commit: - _LOG.warning("Experiment %s git expected: %s %s", - self, exp_info.git_repo, exp_info.git_commit) + _LOG.warning( + "Experiment %s git expected: %s %s", + self, + exp_info.git_repo, + exp_info.git_commit, + ) def merge(self, experiment_ids: List[str]) -> None: _LOG.info("Merge: %s <- %s", self._experiment_id, experiment_ids) @@ -110,33 +131,42 @@ def load_tunable_config(self, config_id: int) -> Dict[str, Any]: def load_telemetry(self, trial_id: int) -> List[Tuple[datetime, str, Any]]: with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select().where( + self._schema.trial_telemetry.select() + .where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == trial_id - ).order_by( + self._schema.trial_telemetry.c.trial_id == trial_id, + ) + .order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) ) # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. 
- return [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) - for row in cur_telemetry.fetchall()] + return [ + (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) + for row in cur_telemetry.fetchall() + ] - def load(self, last_trial_id: int = -1, - ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: + def load( + self, + last_trial_id: int = -1, + ) -> Tuple[List[int], List[dict], List[Optional[Dict[str, Any]]], List[Status]]: with self._engine.connect() as conn: cur_trials = conn.execute( - self._schema.trial.select().with_only_columns( + self._schema.trial.select() + .with_only_columns( self._schema.trial.c.trial_id, self._schema.trial.c.config_id, self._schema.trial.c.status, - ).where( + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id > last_trial_id, - self._schema.trial.c.status.in_(['SUCCEEDED', 'FAILED', 'TIMED_OUT']), - ).order_by( + self._schema.trial.c.status.in_(["SUCCEEDED", "FAILED", "TIMED_OUT"]), + ) + .order_by( self._schema.trial.c.trial_id.asc(), ) ) @@ -150,12 +180,21 @@ def load(self, last_trial_id: int = -1, stat = Status[trial.status] status.append(stat) trial_ids.append(trial.trial_id) - configs.append(self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id)) + configs.append( + self._get_key_val( + conn, self._schema.config_param, "param", config_id=trial.config_id + ) + ) if stat.is_succeeded(): - scores.append(self._get_key_val( - conn, self._schema.trial_result, "metric", - exp_id=self._experiment_id, trial_id=trial.trial_id)) + scores.append( + self._get_key_val( + conn, + self._schema.trial_result, + "metric", + exp_id=self._experiment_id, + trial_id=trial.trial_id, + ) + ) else: scores.append(None) @@ -172,49 +211,60 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D select( column(f"{field}_id"), column(f"{field}_value"), - ).select_from(table).where( - *[column(key) == val for (key, val) in kwargs.items()] ) + .select_from(table) + .where(*[column(key) == val for (key, val) in kwargs.items()]) + ) + # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to + # avoid naming conflicts. + return dict( + row._tuple() for row in cur_result.fetchall() # pylint: disable=protected-access ) - # NOTE: `Row._tuple()` is NOT a protected member; the class uses `_` to avoid naming conflicts. 
- return dict(row._tuple() for row in cur_result.fetchall()) # pylint: disable=protected-access @staticmethod - def _save_params(conn: Connection, table: Table, - params: Dict[str, Any], **kwargs: Any) -> None: + def _save_params( + conn: Connection, table: Table, params: Dict[str, Any], **kwargs: Any + ) -> None: if not params: return - conn.execute(table.insert(), [ - { - **kwargs, - "param_id": key, - "param_value": nullable(str, val) - } - for (key, val) in params.items() - ]) + conn.execute( + table.insert(), + [ + {**kwargs, "param_id": key, "param_value": nullable(str, val)} + for (key, val) in params.items() + ], + ) def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Storage.Trial]: timestamp = utcify_timestamp(timestamp, origin="local") _LOG.info("Retrieve pending trials for: %s @ %s", self._experiment_id, timestamp) if running: - pending_status = ['PENDING', 'READY', 'RUNNING'] + pending_status = ["PENDING", "READY", "RUNNING"] else: - pending_status = ['PENDING'] + pending_status = ["PENDING"] with self._engine.connect() as conn: - cur_trials = conn.execute(self._schema.trial.select().where( - self._schema.trial.c.exp_id == self._experiment_id, - (self._schema.trial.c.ts_start.is_(None) | - (self._schema.trial.c.ts_start <= timestamp)), - self._schema.trial.c.ts_end.is_(None), - self._schema.trial.c.status.in_(pending_status), - )) + cur_trials = conn.execute( + self._schema.trial.select().where( + self._schema.trial.c.exp_id == self._experiment_id, + ( + self._schema.trial.c.ts_start.is_(None) + | (self._schema.trial.c.ts_start <= timestamp) + ), + self._schema.trial.c.ts_end.is_(None), + self._schema.trial.c.status.in_(pending_status), + ) + ) for trial in cur_trials.fetchall(): tunables = self._get_key_val( - conn, self._schema.config_param, "param", - config_id=trial.config_id) + conn, self._schema.config_param, "param", config_id=trial.config_id + ) config = self._get_key_val( - conn, self._schema.trial_param, "param", - exp_id=self._experiment_id, trial_id=trial.trial_id) + conn, + self._schema.trial_param, + "param", + exp_id=self._experiment_id, + trial_id=trial.trial_id, + ) yield Trial( engine=self._engine, schema=self._schema, @@ -233,42 +283,55 @@ def _get_config_id(self, conn: Connection, tunables: TunableGroups) -> int: If the config does not exist, create a new record for it. 
""" - config_hash = hashlib.sha256(str(tunables).encode('utf-8')).hexdigest() - cur_config = conn.execute(self._schema.config.select().where( - self._schema.config.c.config_hash == config_hash - )).fetchone() + config_hash = hashlib.sha256(str(tunables).encode("utf-8")).hexdigest() + cur_config = conn.execute( + self._schema.config.select().where(self._schema.config.c.config_hash == config_hash) + ).fetchone() if cur_config is not None: return int(cur_config.config_id) # mypy doesn't know it's always int # Config not found, create a new one: - config_id: int = conn.execute(self._schema.config.insert().values( - config_hash=config_hash)).inserted_primary_key[0] + config_id: int = conn.execute( + self._schema.config.insert().values(config_hash=config_hash) + ).inserted_primary_key[0] self._save_params( - conn, self._schema.config_param, + conn, + self._schema.config_param, {tunable.name: tunable.value for (tunable, _group) in tunables}, - config_id=config_id) + config_id=config_id, + ) return config_id - def new_trial(self, tunables: TunableGroups, ts_start: Optional[datetime] = None, - config: Optional[Dict[str, Any]] = None) -> Storage.Trial: + def new_trial( + self, + tunables: TunableGroups, + ts_start: Optional[datetime] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Storage.Trial: ts_start = utcify_timestamp(ts_start or datetime.now(UTC), origin="local") _LOG.debug("Create trial: %s:%d @ %s", self._experiment_id, self._trial_id, ts_start) with self._engine.begin() as conn: try: config_id = self._get_config_id(conn, tunables) - conn.execute(self._schema.trial.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - config_id=config_id, - ts_start=ts_start, - status='PENDING', - )) + conn.execute( + self._schema.trial.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + config_id=config_id, + ts_start=ts_start, + status="PENDING", + ) + ) # Note: config here is the framework config, not the target # environment config (i.e., tunables). if config is not None: self._save_params( - conn, self._schema.trial_param, config, - exp_id=self._experiment_id, trial_id=self._trial_id) + conn, + self._schema.trial_param, + config, + exp_id=self._experiment_id, + trial_id=self._trial_id, + ) trial = Trial( engine=self._engine, diff --git a/mlos_bench/mlos_bench/storage/sql/experiment_data.py b/mlos_bench/mlos_bench/storage/sql/experiment_data.py index b92885d1fd..f29b9fedda 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment_data.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment_data.py @@ -33,14 +33,17 @@ class ExperimentSqlData(ExperimentData): scripts and mlos_bench configuration files. 
""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - description: str, - root_env_config: str, - git_repo: str, - git_commit: str): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + description: str, + root_env_config: str, + git_repo: str, + git_commit: str, + ): super().__init__( experiment_id=experiment_id, description=description, @@ -55,9 +58,11 @@ def __init__(self, *, def objectives(self) -> Dict[str, Literal["min", "max"]]: with self._engine.connect() as conn: objectives_db_data = conn.execute( - self._schema.objectives.select().where( + self._schema.objectives.select() + .where( self._schema.objectives.c.exp_id == self._experiment_id, - ).order_by( + ) + .order_by( self._schema.objectives.c.weight.desc(), self._schema.objectives.c.optimization_target.asc(), ) @@ -67,7 +72,8 @@ def objectives(self) -> Dict[str, Literal["min", "max"]]: for objective in objectives_db_data.fetchall() } - # TODO: provide a way to get individual data to avoid repeated bulk fetches where only small amounts of data is accessed. + # TODO: provide a way to get individual data to avoid repeated bulk fetches + # where only small amounts of data is accessed. # Or else make the TrialData object lazily populate. @property @@ -78,13 +84,17 @@ def trials(self) -> Dict[int, TrialData]: def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: with self._engine.connect() as conn: tunable_config_trial_groups = conn.execute( - self._schema.trial.select().with_only_columns( + self._schema.trial.select() + .with_only_columns( self._schema.trial.c.config_id, - func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - 'tunable_config_trial_group_id'), - ).where( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label("tunable_config_trial_group_id"), # pylint: disable=not-callable + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -95,7 +105,7 @@ def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: schema=self._schema, experiment_id=self._experiment_id, tunable_config_id=tunable_config_trial_group.config_id, - tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id, + tunable_config_trial_group_id=tunable_config_trial_group.tunable_config_trial_group_id, # pylint:disable=line-too-long # noqa ) for tunable_config_trial_group in tunable_config_trial_groups.fetchall() } @@ -104,11 +114,14 @@ def tunable_config_trial_groups(self) -> Dict[int, TunableConfigTrialGroupData]: def tunable_configs(self) -> Dict[int, TunableConfigData]: with self._engine.connect() as conn: tunable_configs = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) @@ -138,20 +151,28 @@ def default_tunable_config_id(self) -> Optional[int]: """ with self._engine.connect() as conn: query_results = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + 
.with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial_param.select().with_only_columns( - func.min(self._schema.trial_param.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - "first_trial_id_with_defaults"), - ).where( + self._schema.trial_param.select() + .with_only_columns( + func.min(self._schema.trial_param.c.trial_id) + .cast(Integer) + .label("first_trial_id_with_defaults"), # pylint: disable=not-callable + ) + .where( self._schema.trial_param.c.exp_id == self._experiment_id, self._schema.trial_param.c.param_id == "is_defaults", - func.lower(self._schema.trial_param.c.param_value, type_=String).in_(["1", "true"]), - ).scalar_subquery() - ) + func.lower(self._schema.trial_param.c.param_value, type_=String).in_( + ["1", "true"] + ), + ) + .scalar_subquery() + ), ) ) min_default_trial_row = query_results.fetchone() @@ -160,17 +181,24 @@ def default_tunable_config_id(self) -> Optional[int]: return min_default_trial_row._tuple()[0] # fallback logic - assume minimum trial_id for experiment query_results = conn.execute( - self._schema.trial.select().with_only_columns( - self._schema.trial.c.config_id.cast(Integer).label('config_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + self._schema.trial.c.config_id.cast(Integer).label("config_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id.in_( - self._schema.trial.select().with_only_columns( - func.min(self._schema.trial.c.trial_id).cast(Integer).label("first_trial_id"), - ).where( + self._schema.trial.select() + .with_only_columns( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label("first_trial_id"), + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, - ).scalar_subquery() - ) + ) + .scalar_subquery() + ), ) ) min_trial_row = query_results.fetchone() diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 3443c9b810..717dc70c2a 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -74,7 +74,6 @@ def __init__(self, engine: Engine): Column("root_env_config", String(1024), nullable=False), Column("git_repo", String(1024), nullable=False), Column("git_commit", String(40), nullable=False), - PrimaryKeyConstraint("exp_id"), ) @@ -89,20 +88,25 @@ def __init__(self, engine: Engine): # Will need to adjust the insert and return values to support this # eventually. 
Column("weight", Float, nullable=True), - PrimaryKeyConstraint("exp_id", "optimization_target"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ) # A workaround for SQLAlchemy issue with autoincrement in DuckDB: if engine.dialect.name == "duckdb": - seq_config_id = Sequence('seq_config_id') - col_config_id = Column("config_id", Integer, seq_config_id, - server_default=seq_config_id.next_value(), - nullable=False, primary_key=True) + seq_config_id = Sequence("seq_config_id") + col_config_id = Column( + "config_id", + Integer, + seq_config_id, + server_default=seq_config_id.next_value(), + nullable=False, + primary_key=True, + ) else: - col_config_id = Column("config_id", Integer, nullable=False, - primary_key=True, autoincrement=True) + col_config_id = Column( + "config_id", Integer, nullable=False, primary_key=True, autoincrement=True + ) self.config = Table( "config", @@ -121,7 +125,6 @@ def __init__(self, engine: Engine): Column("ts_end", DateTime), # Should match the text IDs of `mlos_bench.environments.Status` enum: Column("status", String(self._STATUS_LEN), nullable=False), - PrimaryKeyConstraint("exp_id", "trial_id"), ForeignKeyConstraint(["exp_id"], [self.experiment.c.exp_id]), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), @@ -135,7 +138,6 @@ def __init__(self, engine: Engine): Column("config_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), - PrimaryKeyConstraint("config_id", "param_id"), ForeignKeyConstraint(["config_id"], [self.config.c.config_id]), ) @@ -149,10 +151,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("param_id", String(self._ID_LEN), nullable=False), Column("param_value", String(self._PARAM_VALUE_LEN)), - PrimaryKeyConstraint("exp_id", "trial_id", "param_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_status = Table( @@ -162,10 +164,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("status", String(self._STATUS_LEN), nullable=False), - UniqueConstraint("exp_id", "trial_id", "ts"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_result = Table( @@ -175,10 +177,10 @@ def __init__(self, engine: Engine): Column("trial_id", Integer, nullable=False), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), - PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) self.trial_telemetry = Table( @@ -189,15 +191,15 @@ def __init__(self, engine: Engine): Column("ts", DateTime(timezone=True), nullable=False, default="now"), Column("metric_id", String(self._ID_LEN), nullable=False), Column("metric_value", String(self._METRIC_VALUE_LEN)), - UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"), - ForeignKeyConstraint(["exp_id", "trial_id"], - [self.trial.c.exp_id, self.trial.c.trial_id]), + ForeignKeyConstraint( + 
["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ), ) _LOG.debug("Schema: %s", self._meta) - def create(self) -> 'DbSchema': + def create(self) -> "DbSchema": """Create the DB schema.""" _LOG.info("Create the DB schema") self._meta.create_all(self._engine) diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index 4ac5116b70..f3a317db59 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -23,10 +23,9 @@ class SqlStorage(Storage): """An implementation of the Storage interface using SQLAlchemy backend.""" - def __init__(self, - config: dict, - global_config: Optional[dict] = None, - service: Optional[Service] = None): + def __init__( + self, config: dict, global_config: Optional[dict] = None, service: Optional[Service] = None + ): super().__init__(config, global_config, service) lazy_schema_create = self._config.pop("lazy_schema_create", False) self._log_sql = self._config.pop("log_sql", False) @@ -43,7 +42,7 @@ def __init__(self, @property def _schema(self) -> DbSchema: """Lazily create schema upon first access.""" - if not hasattr(self, '_db_schema'): + if not hasattr(self, "_db_schema"): self._db_schema = DbSchema(self._engine).create() if _LOG.isEnabledFor(logging.DEBUG): _LOG.debug("DDL statements:\n%s", self._schema) @@ -52,13 +51,16 @@ def _schema(self) -> DbSchema: def __repr__(self) -> str: return self._repr - def experiment(self, *, - experiment_id: str, - trial_id: int, - root_env_config: str, - description: str, - tunables: TunableGroups, - opt_targets: Dict[str, Literal['min', 'max']]) -> Storage.Experiment: + def experiment( + self, + *, + experiment_id: str, + trial_id: int, + root_env_config: str, + description: str, + tunables: TunableGroups, + opt_targets: Dict[str, Literal["min", "max"]], + ) -> Storage.Experiment: return Experiment( engine=self._engine, schema=self._schema, diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 2a43c2c671..13233fd9a3 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -23,15 +23,18 @@ class Trial(Storage.Trial): """Store the results of a single run of the experiment in SQL database.""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunables: TunableGroups, - experiment_id: str, - trial_id: int, - config_id: int, - opt_targets: Dict[str, Literal['min', 'max']], - config: Optional[Dict[str, Any]] = None): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + tunables: TunableGroups, + experiment_id: str, + trial_id: int, + config_id: int, + opt_targets: Dict[str, Literal["min", "max"]], + config: Optional[Dict[str, Any]] = None, + ): super().__init__( tunables=tunables, experiment_id=experiment_id, @@ -43,9 +46,9 @@ def __init__(self, *, self._engine = engine self._schema = schema - def update(self, status: Status, timestamp: datetime, - metrics: Optional[Dict[str, Any]] = None - ) -> Optional[Dict[str, Any]]: + def update( + self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None + ) -> Optional[Dict[str, Any]]: # Make sure to convert the timestamp to UTC before storing it in the database. 
timestamp = utcify_timestamp(timestamp, origin="local") metrics = super().update(status, timestamp, metrics) @@ -55,13 +58,16 @@ def update(self, status: Status, timestamp: datetime, if status.is_completed(): # Final update of the status and ts_end: cur_status = conn.execute( - self._schema.trial.update().where( + self._schema.trial.update() + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ['SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), - ).values( + ["SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] + ), + ) + .values( status=status.name, ts_end=timestamp, ) @@ -69,29 +75,37 @@ def update(self, status: Status, timestamp: datetime, if cur_status.rowcount not in {1, -1}: _LOG.warning("Trial %s :: update failed: %s", self, status) raise RuntimeError( - f"Failed to update the status of the trial {self} to {status}." + - f" ({cur_status.rowcount} rows)") + f"Failed to update the status of the trial {self} to {status}." + + f" ({cur_status.rowcount} rows)" + ) if metrics: - conn.execute(self._schema.trial_result.insert().values([ - { - "exp_id": self._experiment_id, - "trial_id": self._trial_id, - "metric_id": key, - "metric_value": nullable(str, val), - } - for (key, val) in metrics.items() - ])) + conn.execute( + self._schema.trial_result.insert().values( + [ + { + "exp_id": self._experiment_id, + "trial_id": self._trial_id, + "metric_id": key, + "metric_value": nullable(str, val), + } + for (key, val) in metrics.items() + ] + ) + ) else: # Update of the status and ts_start when starting the trial: assert metrics is None, f"Unexpected metrics for status: {status}" cur_status = conn.execute( - self._schema.trial.update().where( + self._schema.trial.update() + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.trial_id == self._trial_id, self._schema.trial.c.ts_end.is_(None), self._schema.trial.c.status.notin_( - ['RUNNING', 'SUCCEEDED', 'CANCELED', 'FAILED', 'TIMED_OUT']), - ).values( + ["RUNNING", "SUCCEEDED", "CANCELED", "FAILED", "TIMED_OUT"] + ), + ) + .values( status=status.name, ts_start=timestamp, ) @@ -104,8 +118,9 @@ def update(self, status: Status, timestamp: datetime, raise return metrics - def update_telemetry(self, status: Status, timestamp: datetime, - metrics: List[Tuple[datetime, str, Any]]) -> None: + def update_telemetry( + self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] + ) -> None: super().update_telemetry(status, timestamp, metrics) # Make sure to convert the timestamp to UTC before storing it in the database. 
timestamp = utcify_timestamp(timestamp, origin="local") @@ -116,16 +131,18 @@ def update_telemetry(self, status: Status, timestamp: datetime, # See Also: comments in with self._engine.begin() as conn: self._update_status(conn, status, timestamp) - for (metric_ts, key, val) in metrics: + for metric_ts, key, val in metrics: with self._engine.begin() as conn: try: - conn.execute(self._schema.trial_telemetry.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=metric_ts, - metric_id=key, - metric_value=nullable(str, val), - )) + conn.execute( + self._schema.trial_telemetry.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=metric_ts, + metric_id=key, + metric_value=nullable(str, val), + ) + ) except IntegrityError as ex: _LOG.warning("Record already exists: %s :: %s", (metric_ts, key, val), ex) @@ -138,12 +155,15 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime) # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") try: - conn.execute(self._schema.trial_status.insert().values( - exp_id=self._experiment_id, - trial_id=self._trial_id, - ts=timestamp, - status=status.name, - )) + conn.execute( + self._schema.trial_status.insert().values( + exp_id=self._experiment_id, + trial_id=self._trial_id, + ts=timestamp, + status=status.name, + ) + ) except IntegrityError as ex: - _LOG.warning("Status with that timestamp already exists: %s %s :: %s", - self, timestamp, ex) + _LOG.warning( + "Status with that timestamp already exists: %s %s :: %s", self, timestamp, ex + ) diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index 18fc0b46ff..690492585b 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -25,15 +25,18 @@ class TrialSqlData(TrialData): """An interface to access the trial data stored in the SQL DB.""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - trial_id: int, - config_id: int, - ts_start: datetime, - ts_end: Optional[datetime], - status: Status): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + trial_id: int, + config_id: int, + ts_start: datetime, + ts_end: Optional[datetime], + status: Status, + ): super().__init__( experiment_id=experiment_id, trial_id=trial_id, @@ -52,8 +55,9 @@ def tunable_config(self) -> TunableConfigData: Note: this corresponds to the Trial object's "tunables" property. 
""" - return TunableConfigSqlData(engine=self._engine, schema=self._schema, - tunable_config_id=self._tunable_config_id) + return TunableConfigSqlData( + engine=self._engine, schema=self._schema, tunable_config_id=self._tunable_config_id + ) @property def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": @@ -64,35 +68,44 @@ def tunable_config_trial_group(self) -> "TunableConfigTrialGroupData": from mlos_bench.storage.sql.tunable_config_trial_group_data import ( TunableConfigTrialGroupSqlData, ) - return TunableConfigTrialGroupSqlData(engine=self._engine, schema=self._schema, - experiment_id=self._experiment_id, - tunable_config_id=self._tunable_config_id) + + return TunableConfigTrialGroupSqlData( + engine=self._engine, + schema=self._schema, + experiment_id=self._experiment_id, + tunable_config_id=self._tunable_config_id, + ) @property def results_df(self) -> pandas.DataFrame: """Retrieve the trials' results from the storage.""" with self._engine.connect() as conn: cur_results = conn.execute( - self._schema.trial_result.select().where( + self._schema.trial_result.select() + .where( self._schema.trial_result.c.exp_id == self._experiment_id, - self._schema.trial_result.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_result.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_result.c.metric_id, ) ) return pandas.DataFrame( [(row.metric_id, row.metric_value) for row in cur_results.fetchall()], - columns=['metric', 'value']) + columns=["metric", "value"], + ) @property def telemetry_df(self) -> pandas.DataFrame: """Retrieve the trials' telemetry from the storage.""" with self._engine.connect() as conn: cur_telemetry = conn.execute( - self._schema.trial_telemetry.select().where( + self._schema.trial_telemetry.select() + .where( self._schema.trial_telemetry.c.exp_id == self._experiment_id, - self._schema.trial_telemetry.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_telemetry.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_telemetry.c.ts, self._schema.trial_telemetry.c.metric_id, ) @@ -100,8 +113,12 @@ def telemetry_df(self) -> pandas.DataFrame: # Not all storage backends store the original zone info. # We try to ensure data is entered in UTC and augment it on return again here. 
return pandas.DataFrame( - [(utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) for row in cur_telemetry.fetchall()], - columns=['ts', 'metric', 'value']) + [ + (utcify_timestamp(row.ts, origin="utc"), row.metric_id, row.metric_value) + for row in cur_telemetry.fetchall() + ], + columns=["ts", "metric", "value"], + ) @property def metadata_df(self) -> pandas.DataFrame: @@ -112,13 +129,16 @@ def metadata_df(self) -> pandas.DataFrame: """ with self._engine.connect() as conn: cur_params = conn.execute( - self._schema.trial_param.select().where( + self._schema.trial_param.select() + .where( self._schema.trial_param.c.exp_id == self._experiment_id, - self._schema.trial_param.c.trial_id == self._trial_id - ).order_by( + self._schema.trial_param.c.trial_id == self._trial_id, + ) + .order_by( self._schema.trial_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_params.fetchall()], - columns=['parameter', 'value']) + columns=["parameter", "value"], + ) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py index 616d5fe823..40225039be 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_data.py @@ -18,10 +18,7 @@ class TunableConfigSqlData(TunableConfigData): A configuration in this context is the set of tunable parameter values. """ - def __init__(self, *, - engine: Engine, - schema: DbSchema, - tunable_config_id: int): + def __init__(self, *, engine: Engine, schema: DbSchema, tunable_config_id: int): super().__init__(tunable_config_id=tunable_config_id) self._engine = engine self._schema = schema @@ -30,12 +27,13 @@ def __init__(self, *, def config_df(self) -> pandas.DataFrame: with self._engine.connect() as conn: cur_config = conn.execute( - self._schema.config_param.select().where( - self._schema.config_param.c.config_id == self._tunable_config_id - ).order_by( + self._schema.config_param.select() + .where(self._schema.config_param.c.config_id == self._tunable_config_id) + .order_by( self._schema.config_param.c.param_id, ) ) return pandas.DataFrame( [(row.param_id, row.param_value) for row in cur_config.fetchall()], - columns=['parameter', 'value']) + columns=["parameter", "value"], + ) diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index 457e81e7c0..31a6df5879 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -31,12 +31,15 @@ class TunableConfigTrialGroupSqlData(TunableConfigTrialGroupData): (e.g., for repeats), which we call a (tunable) config trial group. 
""" - def __init__(self, *, - engine: Engine, - schema: DbSchema, - experiment_id: str, - tunable_config_id: int, - tunable_config_trial_group_id: Optional[int] = None): + def __init__( + self, + *, + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: int, + tunable_config_trial_group_id: Optional[int] = None, + ): super().__init__( experiment_id=experiment_id, tunable_config_id=tunable_config_id, @@ -49,20 +52,26 @@ def _get_tunable_config_trial_group_id(self) -> int: """Retrieve the trial's tunable_config_trial_group_id from the storage.""" with self._engine.connect() as conn: tunable_config_trial_group = conn.execute( - self._schema.trial.select().with_only_columns( - func.min(self._schema.trial.c.trial_id).cast(Integer).label( # pylint: disable=not-callable - 'tunable_config_trial_group_id'), - ).where( + self._schema.trial.select() + .with_only_columns( + func.min(self._schema.trial.c.trial_id) + .cast(Integer) + .label("tunable_config_trial_group_id"), # pylint: disable=not-callable + ) + .where( self._schema.trial.c.exp_id == self._experiment_id, self._schema.trial.c.config_id == self._tunable_config_id, - ).group_by( + ) + .group_by( self._schema.trial.c.exp_id, self._schema.trial.c.config_id, ) ) row = tunable_config_trial_group.fetchone() assert row is not None - return row._tuple()[0] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy + return row._tuple()[ + 0 + ] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy @property def tunable_config(self) -> TunableConfigData: @@ -83,8 +92,12 @@ def trials(self) -> Dict[int, "TrialData"]: trials : Dict[int, TrialData] A dictionary of the trials' data, keyed by trial id. """ - return common.get_trials(self._engine, self._schema, self._experiment_id, self._tunable_config_id) + return common.get_trials( + self._engine, self._schema, self._experiment_id, self._tunable_config_id + ) @property def results_df(self) -> pandas.DataFrame: - return common.get_results_df(self._engine, self._schema, self._experiment_id, self._tunable_config_id) + return common.get_results_df( + self._engine, self._schema, self._experiment_id, self._tunable_config_id + ) diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index d1b2547876..2de66a9aab 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -11,9 +11,9 @@ from mlos_bench.storage.base_storage import Storage -def from_config(config_file: str, - global_configs: Optional[List[str]] = None, - **kwargs: Any) -> Storage: +def from_config( + config_file: str, global_configs: Optional[List[str]] = None, **kwargs: Any +) -> Storage: """ Create a new storage object from JSON5 config file. 
@@ -34,7 +34,7 @@ def from_config(config_file: str, config_path: List[str] = kwargs.get("config_path", []) config_loader = ConfigPersistenceService({"config_path": config_path}) global_config = {} - for fname in (global_configs or []): + for fname in global_configs or []: config = config_loader.load_config(fname, ConfigSchema.GLOBALS) global_config.update(config) config_path += config.get("config_path", []) diff --git a/mlos_bench/mlos_bench/storage/util.py b/mlos_bench/mlos_bench/storage/util.py index 1ac48b4fab..173f7d95d6 100644 --- a/mlos_bench/mlos_bench/storage/util.py +++ b/mlos_bench/mlos_bench/storage/util.py @@ -23,16 +23,18 @@ def kv_df_to_dict(dataframe: pandas.DataFrame) -> Dict[str, Optional[TunableValu A dataframe with exactly two columns, 'parameter' (or 'metric') and 'value', where 'parameter' is a string and 'value' is some TunableValue or None. """ - if dataframe.columns.tolist() == ['metric', 'value']: + if dataframe.columns.tolist() == ["metric", "value"]: dataframe = dataframe.copy() - dataframe.rename(columns={'metric': 'parameter'}, inplace=True) - assert dataframe.columns.tolist() == ['parameter', 'value'] + dataframe.rename(columns={"metric": "parameter"}, inplace=True) + assert dataframe.columns.tolist() == ["parameter", "value"] data = {} - for _, row in dataframe.astype('O').iterrows(): - if not isinstance(row['value'], TunableValueTypeTuple): + for _, row in dataframe.astype("O").iterrows(): + if not isinstance(row["value"], TunableValueTypeTuple): raise TypeError(f"Invalid column type: {type(row['value'])} value: {row['value']}") - assert isinstance(row['parameter'], str) - if row['parameter'] in data: + assert isinstance(row["parameter"], str) + if row["parameter"] in data: raise ValueError(f"Duplicate parameter '{row['parameter']}' in dataframe") - data[row['parameter']] = try_parse_val(row['value']) if isinstance(row['value'], str) else row['value'] + data[row["parameter"]] = ( + try_parse_val(row["value"]) if isinstance(row["value"], str) else row["value"] + ) return data diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index dee543357f..4fca4fc449 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -30,26 +30,34 @@ None, ] ZONE_INFO: List[Optional[tzinfo]] = [ - nullable(pytz.timezone, zone_name) - for zone_name in ZONE_NAMES + nullable(pytz.timezone, zone_name) for zone_name in ZONE_NAMES ] # A decorator for tests that require docker. # Use with @requires_docker above a test_...() function. -DOCKER = shutil.which('docker') +DOCKER = shutil.which("docker") if DOCKER: - cmd = run("docker builder inspect default || docker buildx inspect default", shell=True, check=False, capture_output=True) + cmd = run( + "docker builder inspect default || docker buildx inspect default", + shell=True, + check=False, + capture_output=True, + ) stdout = cmd.stdout.decode() - if cmd.returncode != 0 or not any(line for line in stdout.splitlines() if 'Platform' in line and 'linux' in line): + if cmd.returncode != 0 or not any( + line for line in stdout.splitlines() if "Platform" in line and "linux" in line + ): debug("Docker is available but missing support for targeting linux platform.") DOCKER = None -requires_docker = pytest.mark.skipif(not DOCKER, reason='Docker with Linux support is not available on this system.') +requires_docker = pytest.mark.skipif( + not DOCKER, reason="Docker with Linux support is not available on this system." +) # A decorator for tests that require ssh. 
# Use with @requires_ssh above a test_...() function. -SSH = shutil.which('ssh') -requires_ssh = pytest.mark.skipif(not SSH, reason='ssh is not available on this system.') +SSH = shutil.which("ssh") +requires_ssh = pytest.mark.skipif(not SSH, reason="ssh is not available on this system.") # A common seed to use to avoid tracking down race conditions and intermingling # issues of seeds across tests that run in non-deterministic parallel orders. @@ -126,8 +134,17 @@ def are_dir_trees_equal(dir1: str, dir2: str) -> bool: """ # See Also: https://stackoverflow.com/a/6681395 dirs_cmp = filecmp.dircmp(dir1, dir2) - if len(dirs_cmp.left_only) > 0 or len(dirs_cmp.right_only) > 0 or len(dirs_cmp.funny_files) > 0: - warning(f"Found differences in dir trees {dir1}, {dir2}:\n{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}") + if ( + len(dirs_cmp.left_only) > 0 + or len(dirs_cmp.right_only) > 0 + or len(dirs_cmp.funny_files) > 0 + ): + warning( + ( + f"Found differences in dir trees {dir1}, {dir2}:\n" + f"{dirs_cmp.diff_files}\n{dirs_cmp.funny_files}" + ) + ) return False (_, mismatch, errors) = filecmp.cmpfiles(dir1, dir2, dirs_cmp.common_files, shallow=False) if len(mismatch) > 0 or len(errors) > 0: diff --git a/mlos_bench/mlos_bench/tests/config/__init__.py b/mlos_bench/mlos_bench/tests/config/__init__.py index ecd1c1ba58..b2b6146a56 100644 --- a/mlos_bench/mlos_bench/tests/config/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/__init__.py @@ -19,9 +19,11 @@ BUILTIN_TEST_CONFIG_PATH = str(files("mlos_bench.tests.config").joinpath("")).replace("\\", "/") -def locate_config_examples(root_dir: str, - config_examples_dir: str, - examples_filter: Optional[Callable[[List[str]], List[str]]] = None) -> List[str]: +def locate_config_examples( + root_dir: str, + config_examples_dir: str, + examples_filter: Optional[Callable[[List[str]], List[str]]] = None, +) -> List[str]: """ Locates all config examples in the given directory. 
diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index 1bea4f4369..7add370011 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -41,7 +41,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ - *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), + *locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ), *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), ] assert configs @@ -49,7 +51,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.skip(reason="Use full Launcher test (below) instead now.") @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: # pragma: no cover +def test_load_cli_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: # pragma: no cover """Tests loading a config example.""" # pylint: disable=too-complex config = config_loader_service.load_config(config_path, ConfigSchema.CLI) @@ -59,7 +63,7 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic assert isinstance(config_paths, list) config_paths.reverse() for path in config_paths: - config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access + config_loader_service._config_path.insert(0, path) # pylint: disable=protected-access # Foreach arg that references another file, see if we can at least load that too. args_to_skip = { @@ -96,7 +100,9 @@ def test_load_cli_config_examples(config_loader_service: ConfigPersistenceServic @pytest.mark.parametrize("config_path", configs) -def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_cli_config_examples_via_launcher( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example via the Launcher.""" config = config_loader_service.load_config(config_path, ConfigSchema.CLI) assert isinstance(config, dict) @@ -104,10 +110,13 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # Try to load the CLI config by instantiating a launcher. # To do this we need to make sure to give it a few extra paths and globals # to look for for our examples. 
- cli_args = f"--config {config_path}" + \ - f" --config-path {files('mlos_bench.config')} --config-path {files('mlos_bench.tests.config')}" + \ - f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" + \ + cli_args = ( + f"--config {config_path}" + f" --config-path {files('mlos_bench.config')} " + f" --config-path {files('mlos_bench.tests.config')}" + f" --config-path {path_join(str(files('mlos_bench.tests.config')), 'globals')}" f" --globals {files('mlos_bench.tests.config')}/experiments/experiment_test_config.jsonc" + ) launcher = Launcher(description=__name__, long_text=config_path, argv=cli_args.split()) assert launcher @@ -118,15 +127,16 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers assert isinstance(config_paths, list) for path in config_paths: # Note: Checks that the order is maintained are handled in launcher_parse_args.py - assert any(config_path.endswith(path) for config_path in launcher.config_loader.config_paths), \ - f"Expected {path} to be in {launcher.config_loader.config_paths}" + assert any( + config_path.endswith(path) for config_path in launcher.config_loader.config_paths + ), f"Expected {path} to be in {launcher.config_loader.config_paths}" - if 'experiment_id' in config: - assert launcher.global_config['experiment_id'] == config['experiment_id'] - if 'trial_id' in config: - assert launcher.global_config['trial_id'] == config['trial_id'] + if "experiment_id" in config: + assert launcher.global_config["experiment_id"] == config["experiment_id"] + if "trial_id" in config: + assert launcher.global_config["trial_id"] == config["trial_id"] - expected_log_level = logging.getLevelName(config.get('log_level', "INFO")) + expected_log_level = logging.getLevelName(config.get("log_level", "INFO")) if isinstance(expected_log_level, int): expected_log_level = logging.getLevelName(expected_log_level) current_log_level = logging.getLevelName(logging.root.getEffectiveLevel()) @@ -134,7 +144,7 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # TODO: Check that the log_file handler is set correctly. - expected_teardown = config.get('teardown', True) + expected_teardown = config.get("teardown", True) assert launcher.teardown == expected_teardown # Note: Testing of "globals" processing handled in launcher_parse_args_test.py @@ -143,22 +153,30 @@ def test_load_cli_config_examples_via_launcher(config_loader_service: ConfigPers # Launcher loaded the expected types as well. 
assert isinstance(launcher.environment, Environment) - env_config = launcher.config_loader.load_config(config["environment"], ConfigSchema.ENVIRONMENT) + env_config = launcher.config_loader.load_config( + config["environment"], ConfigSchema.ENVIRONMENT + ) assert check_class_name(launcher.environment, env_config["class"]) assert isinstance(launcher.optimizer, Optimizer) if "optimizer" in config: - opt_config = launcher.config_loader.load_config(config["optimizer"], ConfigSchema.OPTIMIZER) + opt_config = launcher.config_loader.load_config( + config["optimizer"], ConfigSchema.OPTIMIZER + ) assert check_class_name(launcher.optimizer, opt_config["class"]) assert isinstance(launcher.storage, Storage) if "storage" in config: - storage_config = launcher.config_loader.load_config(config["storage"], ConfigSchema.STORAGE) + storage_config = launcher.config_loader.load_config( + config["storage"], ConfigSchema.STORAGE + ) assert check_class_name(launcher.storage, storage_config["class"]) assert isinstance(launcher.scheduler, Scheduler) if "scheduler" in config: - scheduler_config = launcher.config_loader.load_config(config["scheduler"], ConfigSchema.SCHEDULER) + scheduler_config = launcher.config_loader.load_config( + config["scheduler"], ConfigSchema.SCHEDULER + ) assert check_class_name(launcher.scheduler, scheduler_config["class"]) # TODO: Check that the launcher assigns the tunables values as expected. diff --git a/mlos_bench/mlos_bench/tests/config/conftest.py b/mlos_bench/mlos_bench/tests/config/conftest.py index 6f8cebb910..5f9167dc85 100644 --- a/mlos_bench/mlos_bench/tests/config/conftest.py +++ b/mlos_bench/mlos_bench/tests/config/conftest.py @@ -20,9 +20,11 @@ @pytest.fixture def config_loader_service() -> ConfigPersistenceService: """Config loader service fixture.""" - return ConfigPersistenceService(config={ - "config_path": [ - str(files("mlos_bench.tests.config")), - path_join(str(files("mlos_bench.tests.config")), "globals"), - ] - }) + return ConfigPersistenceService( + config={ + "config_path": [ + str(files("mlos_bench.tests.config")), + path_join(str(files("mlos_bench.tests.config")), "globals"), + ] + } + ) diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 1b9103c5af..c7d0f9ba44 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -25,16 +25,24 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" - configs_to_filter = [config_path for config_path in configs_to_filter if not config_path.endswith("-tunables.jsonc")] + configs_to_filter = [ + config_path + for config_path in configs_to_filter + if not config_path.endswith("-tunables.jsonc") + ] return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_environment_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests 
loading an environment config example.""" envs = load_environment_config_examples(config_loader_service, config_path) for env in envs: @@ -42,11 +50,15 @@ def test_load_environment_config_examples(config_loader_service: ConfigPersisten assert isinstance(env, Environment) -def load_environment_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> List[Environment]: +def load_environment_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> List[Environment]: """Loads an environment config example.""" # Make sure that any "required_args" are provided. - global_config = config_loader_service.load_config("experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS) - global_config.setdefault('trial_id', 1) # normally populated by Launcher + global_config = config_loader_service.load_config( + "experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS + ) + global_config.setdefault("trial_id", 1) # normally populated by Launcher # Make sure we have the required services for the envs being used. mock_service_configs = [ @@ -58,24 +70,34 @@ def load_environment_config_examples(config_loader_service: ConfigPersistenceSer "services/remote/mock/mock_auth_service.jsonc", ] - tunable_groups = TunableGroups() # base tunable groups that all others get built on + tunable_groups = TunableGroups() # base tunable groups that all others get built on for mock_service_config_path in mock_service_configs: - mock_service_config = config_loader_service.load_config(mock_service_config_path, ConfigSchema.SERVICE) - config_loader_service.register(config_loader_service.build_service( - config=mock_service_config, parent=config_loader_service).export()) + mock_service_config = config_loader_service.load_config( + mock_service_config_path, ConfigSchema.SERVICE + ) + config_loader_service.register( + config_loader_service.build_service( + config=mock_service_config, parent=config_loader_service + ).export() + ) envs = config_loader_service.load_environment_list( - config_path, tunable_groups, global_config, service=config_loader_service) + config_path, tunable_groups, global_config, service=config_loader_service + ) return envs -composite_configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/") +composite_configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/" +) assert composite_configs @pytest.mark.parametrize("config_path", composite_configs) -def test_load_composite_env_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_composite_env_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a composite env config example.""" envs = load_environment_config_examples(config_loader_service, config_path) assert len(envs) == 1 @@ -88,17 +110,23 @@ def test_load_composite_env_config_examples(config_loader_service: ConfigPersist assert child_env.tunable_params is not None checked_child_env_groups = set() - for (child_tunable, child_group) in child_env.tunable_params: + for child_tunable, child_group in child_env.tunable_params: # Lookup that tunable in the composite env. assert child_tunable in composite_env.tunable_params - (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable(child_tunable) - assert child_tunable is composite_tunable # Check that the tunables are the same object. 
+ (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable( + child_tunable + ) + assert ( + child_tunable is composite_tunable + ) # Check that the tunables are the same object. if child_group.name not in checked_child_env_groups: assert child_group is composite_group checked_child_env_groups.add(child_group.name) - # Check that when we change a child env, it's value is reflected in the composite env as well. - # That is to say, they refer to the same objects, despite having potentially been loaded from separate configs. + # Check that when we change a child env, it's value is reflected in the + # composite env as well. + # That is to say, they refer to the same objects, despite having + # potentially been loaded from separate configs. if child_tunable.is_categorical: old_cat_value = child_tunable.category assert child_tunable.value == old_cat_value diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index 708bb0f55c..c7525a2960 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -26,16 +26,34 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ - # *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs), - *locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, "experiments", filter_configs), - *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), - *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, "experiments", filter_configs), + # *locate_config_examples( + # ConfigPersistenceService.BUILTIN_CONFIG_PATH, + # CONFIG_TYPE, + # filter_configs, + # ), + *locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + "experiments", + filter_configs, + ), + *locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, + ), + *locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + "experiments", + filter_configs, + ), ] assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_globals_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_globals_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.GLOBALS) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index ad4dae94f8..e08f4d593b 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -28,12 +28,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_optimizer_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def 
test_load_optimizer_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index 5f2f24e519..bd6921d8c2 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -30,14 +30,17 @@ def __hash__(self) -> int: # The different type of schema test cases we expect to have. -_SCHEMA_TEST_TYPES = {x.test_case_type: x for x in ( - SchemaTestType(test_case_type='good', test_case_subtypes={'full', 'partial'}), - SchemaTestType(test_case_type='bad', test_case_subtypes={'invalid', 'unhandled'}), -)} +_SCHEMA_TEST_TYPES = { + x.test_case_type: x + for x in ( + SchemaTestType(test_case_type="good", test_case_subtypes={"full", "partial"}), + SchemaTestType(test_case_type="bad", test_case_subtypes={"invalid", "unhandled"}), + ) +} @dataclass -class SchemaTestCaseInfo(): +class SchemaTestCaseInfo: """Some basic info about a schema test case.""" config: Dict[str, Any] @@ -54,15 +57,18 @@ def check_schema_dir_layout(test_cases_root: str) -> None: extra configs or test cases. """ for test_case_dir in os.listdir(test_cases_root): - if test_case_dir == 'README.md': + if test_case_dir == "README.md": continue if test_case_dir not in _SCHEMA_TEST_TYPES: raise NotImplementedError(f"Unhandled test case type: {test_case_dir}") for test_case_subdir in os.listdir(os.path.join(test_cases_root, test_case_dir)): - if test_case_subdir == 'README.md': + if test_case_subdir == "README.md": continue if test_case_subdir not in _SCHEMA_TEST_TYPES[test_case_dir].test_case_subtypes: - raise NotImplementedError(f"Unhandled test case subtype {test_case_subdir} for test case type {test_case_dir}") + raise NotImplementedError( + f"Unhandled test case subtype {test_case_subdir} " + f"for test case type {test_case_dir}" + ) @dataclass @@ -76,15 +82,21 @@ class TestCases: def get_schema_test_cases(test_cases_root: str) -> TestCases: """Gets a dict of schema test cases from the given root.""" - test_cases = TestCases(by_path={}, - by_type={x: {} for x in _SCHEMA_TEST_TYPES}, - by_subtype={y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes}) + test_cases = TestCases( + by_path={}, + by_type={x: {} for x in _SCHEMA_TEST_TYPES}, + by_subtype={ + y: {} for x in _SCHEMA_TEST_TYPES for y in _SCHEMA_TEST_TYPES[x].test_case_subtypes + }, + ) check_schema_dir_layout(test_cases_root) # Note: we sort the test cases so that we can deterministically test them in parallel. 
- for (test_case_type, schema_test_type) in _SCHEMA_TEST_TYPES.items(): + for test_case_type, schema_test_type in _SCHEMA_TEST_TYPES.items(): for test_case_subtype in schema_test_type.test_case_subtypes: - for test_case_file in locate_config_examples(test_cases_root, os.path.join(test_case_type, test_case_subtype)): - with open(test_case_file, mode='r', encoding='utf-8') as test_case_fh: + for test_case_file in locate_config_examples( + test_cases_root, os.path.join(test_case_type, test_case_subtype) + ): + with open(test_case_file, mode="r", encoding="utf-8") as test_case_fh: try: test_case_info = SchemaTestCaseInfo( config=json5.load(test_case_fh), @@ -93,8 +105,12 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: test_case_subtype=test_case_subtype, ) test_cases.by_path[test_case_info.test_case_file] = test_case_info - test_cases.by_type[test_case_info.test_case_type][test_case_info.test_case_file] = test_case_info - test_cases.by_subtype[test_case_info.test_case_subtype][test_case_info.test_case_file] = test_case_info + test_cases.by_type[test_case_info.test_case_type][ + test_case_info.test_case_file + ] = test_case_info + test_cases.by_subtype[test_case_info.test_case_subtype][ + test_case_info.test_case_file + ] = test_case_info except Exception as ex: raise RuntimeError("Failed to load test case: " + test_case_file) from ex assert test_cases @@ -106,7 +122,9 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: return test_cases -def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: +def check_test_case_against_schema( + test_case: SchemaTestCaseInfo, schema_type: ConfigSchema +) -> None: """ Checks the given test case against the given schema. @@ -131,7 +149,9 @@ def check_test_case_against_schema(test_case: SchemaTestCaseInfo, schema_type: C raise NotImplementedError(f"Unknown test case type: {test_case.test_case_type}") -def check_test_case_config_with_extra_param(test_case: SchemaTestCaseInfo, schema_type: ConfigSchema) -> None: +def check_test_case_config_with_extra_param( + test_case: SchemaTestCaseInfo, schema_type: ConfigSchema +) -> None: """Checks that the config fails to validate if extra params are present in certain places. """ diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 3ef2b56654..404602b724 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -24,14 +24,16 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_cli_configs_against_schema(test_case_name: str) -> None: """Checks that the CLI config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.CLI) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat dicts + # with minor constraints on them, so adding/removing params doesn't + # invalidate it against all of the config types. 
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) @@ -40,9 +42,12 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: """Checks that the cli config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI + ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat dicts + # with minor constraints on them, so adding/removing params doesn't + # invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index 84381c4a6b..3e9abdbb90 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -31,17 +31,23 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_ENV_CLASSES = { - ScriptEnv # ScriptEnv is ABCMeta abstract, but there's no good way to test that dynamically in Python. + # ScriptEnv is ABCMeta abstract, but there's no good way to test that + # dynamically in Python. + ScriptEnv, } -expected_environment_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass - in get_all_concrete_subclasses(Environment, pkg_name='mlos_bench') - if subclass not in NON_CONFIG_ENV_CLASSES] +expected_environment_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Environment, pkg_name="mlos_bench") + if subclass not in NON_CONFIG_ENV_CLASSES +] assert expected_environment_class_names COMPOSITE_ENV_CLASS_NAME = CompositeEnv.__module__ + "." + CompositeEnv.__name__ -expected_leaf_environment_class_names = [subclass_name for subclass_name in expected_environment_class_names - if subclass_name != COMPOSITE_ENV_CLASS_NAME] +expected_leaf_environment_class_names = [ + subclass_name + for subclass_name in expected_environment_class_names + if subclass_name != COMPOSITE_ENV_CLASS_NAME +] # Do the full cross product of all the test cases and all the Environment types. @@ -55,11 +61,13 @@ def test_case_coverage_mlos_bench_environment_type(test_case_subtype: str, env_c if try_resolve_class_name(test_case.config.get("class")) == env_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}") + f"Missing test case for subtype {test_case_subtype} for Environment class {env_class}" + ) # Now we actually perform all of those validation tests. 
+ @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_environment_configs_against_schema(test_case_name: str) -> None: """Checks that the environment config validates against the schema.""" @@ -72,5 +80,9 @@ def test_environment_configs_with_extra_param(test_case_name: str) -> None: """Checks that the environment config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py index f5a5b83f9f..bcfc0aeb79 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/globals/test_globals_schemas.py @@ -23,12 +23,14 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_globals_configs_against_schema(test_case_name: str) -> None: """Checks that the CLI config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.GLOBALS) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat dicts + # with minor constraints on them, so adding/removing params doesn't + # invalidate it against all of the config types. check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index 0c05cf7323..00ab6ab9d1 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -31,12 +31,16 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_optimizer_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Optimizer, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_optimizer_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Optimizer, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_optimizer_class_names -# Also make sure that we check for configs where the optimizer_type or space_adapter_type are left unspecified (None). +# Also make sure that we check for configs where the optimizer_type or +# space_adapter_type are left unspecified (None). 
expected_mlos_core_optimizer_types = list(OptimizerType) + [None] assert expected_mlos_core_optimizer_types @@ -48,7 +52,9 @@ # Do the full cross product of all the test cases and all the optimizer types. @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names) -def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_bench_optimizer_type: str) -> None: +def test_case_coverage_mlos_bench_optimizer_type( + test_case_subtype: str, mlos_bench_optimizer_type: str +) -> None: """Checks to see if there is a given type of test case for the given mlos_bench optimizer type. """ @@ -56,7 +62,10 @@ def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_be if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_optimizer_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Optimizer class {mlos_bench_optimizer_type}") + f"Missing test case for subtype {test_case_subtype} " + f"for Optimizer class {mlos_bench_optimizer_type}" + ) + # Being a little lazy for the moment and relaxing the requirement that we have # a subtype test case for each optimizer and space adapter combo. @@ -65,47 +74,60 @@ def test_case_coverage_mlos_bench_optimizer_type(test_case_subtype: str, mlos_be @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types) -def test_case_coverage_mlos_core_optimizer_type(test_case_type: str, - mlos_core_optimizer_type: Optional[OptimizerType]) -> None: +def test_case_coverage_mlos_core_optimizer_type( + test_case_type: str, mlos_core_optimizer_type: Optional[OptimizerType] +) -> None: """Checks to see if there is a given type of test case for the given mlos_core optimizer type. """ optimizer_name = None if mlos_core_optimizer_type is None else mlos_core_optimizer_type.name for test_case in TEST_CASES.by_type[test_case_type].values(): - if try_resolve_class_name(test_case.config.get("class")) \ - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": + if ( + try_resolve_class_name(test_case.config.get("class")) + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" + ): optimizer_type = None if test_case.config.get("config"): optimizer_type = test_case.config["config"].get("optimizer_type", None) if optimizer_type == optimizer_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for MlosCore Optimizer type {mlos_core_optimizer_type}") + f"Missing test case for type {test_case_type} " + f"for MlosCore Optimizer type {mlos_core_optimizer_type}" + ) @pytest.mark.parametrize("test_case_type", sorted(TEST_CASES.by_type)) # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types) -def test_case_coverage_mlos_core_space_adapter_type(test_case_type: str, - mlos_core_space_adapter_type: Optional[SpaceAdapterType]) -> None: +def test_case_coverage_mlos_core_space_adapter_type( + test_case_type: str, mlos_core_space_adapter_type: Optional[SpaceAdapterType] +) -> None: """Checks to see if there is a given type of test case for the given mlos_core space adapter type. 
""" - space_adapter_name = None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name + space_adapter_name = ( + None if mlos_core_space_adapter_type is None else mlos_core_space_adapter_type.name + ) for test_case in TEST_CASES.by_type[test_case_type].values(): - if try_resolve_class_name(test_case.config.get("class")) \ - == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer": + if ( + try_resolve_class_name(test_case.config.get("class")) + == "mlos_bench.optimizers.mlos_core_optimizer.MlosCoreOptimizer" + ): space_adapter_type = None if test_case.config.get("config"): space_adapter_type = test_case.config["config"].get("space_adapter_type", None) if space_adapter_type == space_adapter_name: return raise NotImplementedError( - f"Missing test case for type {test_case_type} for SpaceAdapter type {mlos_core_space_adapter_type}") + f"Missing test case for type {test_case_type} " + f"for SpaceAdapter type {mlos_core_space_adapter_type}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_optimizer_configs_against_schema(test_case_name: str) -> None: """Checks that the optimizer config validates against the schema.""" @@ -118,5 +140,9 @@ def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: """Checks that the optimizer config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 0908252971..8b29cfbd08 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -28,9 +28,12 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_scheduler_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Scheduler, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_scheduler_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Scheduler, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_scheduler_class_names # Do the full cross product of all the test cases and all the scheduler types. @@ -38,7 +41,9 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names) -def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_bench_scheduler_type: str) -> None: +def test_case_coverage_mlos_bench_scheduler_type( + test_case_subtype: str, mlos_bench_scheduler_type: str +) -> None: """Checks to see if there is a given type of test case for the given mlos_bench scheduler type. 
""" @@ -46,7 +51,10 @@ def test_case_coverage_mlos_bench_scheduler_type(test_case_subtype: str, mlos_be if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_scheduler_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Scheduler class {mlos_bench_scheduler_type}") + f"Missing test case for subtype {test_case_subtype} " + f"for Scheduler class {mlos_bench_scheduler_type}" + ) + # Now we actually perform all of those validation tests. @@ -63,8 +71,12 @@ def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: """Checks that the scheduler config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 92b8e69110..0f7e3ef7f2 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -36,16 +36,21 @@ # Dynamically enumerate some of the cases we want to make sure we cover. NON_CONFIG_SERVICE_CLASSES = { - ConfigPersistenceService, # configured thru the launcher cli args - TempDirContextService, # ABCMeta abstract class, but no good way to test that dynamically in Python. - AzureDeploymentService, # ABCMeta abstract base class - SshService, # ABCMeta abstract base class + # configured thru the launcher cli args + ConfigPersistenceService, + # ABCMeta abstract class, but no good way to test that dynamically in Python. + TempDirContextService, + # ABCMeta abstract base class + AzureDeploymentService, + # ABCMeta abstract base class + SshService, } -expected_service_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass - in get_all_concrete_subclasses(Service, pkg_name='mlos_bench') - if subclass not in NON_CONFIG_SERVICE_CLASSES] +expected_service_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses(Service, pkg_name="mlos_bench") + if subclass not in NON_CONFIG_SERVICE_CLASSES +] assert expected_service_class_names @@ -59,7 +64,7 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c for test_case in TEST_CASES.by_subtype[test_case_subtype].values(): config_list: List[Dict[str, Any]] if not isinstance(test_case.config, dict): - continue # type: ignore[unreachable] + continue # type: ignore[unreachable] if "class" not in test_case.config: config_list = test_case.config["services"] else: @@ -68,11 +73,13 @@ def test_case_coverage_mlos_bench_service_type(test_case_subtype: str, service_c if try_resolve_class_name(config.get("class")) == service_class: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for service class {service_class}") + f"Missing test case for subtype {test_case_subtype} for service class {service_class}" + ) # Now we actually perform all of those validation tests. 
+ @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_service_configs_against_schema(test_case_name: str) -> None: """Checks that the service config validates against the schema.""" @@ -85,5 +92,9 @@ def test_service_configs_with_extra_param(test_case_name: str) -> None: """Checks that the service config fails to validate if extra params are present in certain places. """ - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index fec23c8284..9d0d604c14 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -26,9 +26,12 @@ # Dynamically enumerate some of the cases we want to make sure we cover. -expected_mlos_bench_storage_class_names = [subclass.__module__ + "." + subclass.__name__ - for subclass in get_all_concrete_subclasses(Storage, # type: ignore[type-abstract] - pkg_name='mlos_bench')] +expected_mlos_bench_storage_class_names = [ + subclass.__module__ + "." + subclass.__name__ + for subclass in get_all_concrete_subclasses( + Storage, pkg_name="mlos_bench" # type: ignore[type-abstract] + ) +] assert expected_mlos_bench_storage_class_names # Do the full cross product of all the test cases and all the storage types. @@ -36,7 +39,9 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names) -def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_bench_storage_type: str) -> None: +def test_case_coverage_mlos_bench_storage_type( + test_case_subtype: str, mlos_bench_storage_type: str +) -> None: """Checks to see if there is a given type of test case for the given mlos_bench storage type. """ @@ -44,11 +49,14 @@ def test_case_coverage_mlos_bench_storage_type(test_case_subtype: str, mlos_benc if try_resolve_class_name(test_case.config.get("class")) == mlos_bench_storage_type: return raise NotImplementedError( - f"Missing test case for subtype {test_case_subtype} for Storage class {mlos_bench_storage_type}") + f"Missing test case for subtype {test_case_subtype} " + f"for Storage class {mlos_bench_storage_type}" + ) # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_storage_configs_against_schema(test_case_name: str) -> None: """Checks that the storage config validates against the schema.""" @@ -61,9 +69,15 @@ def test_storage_configs_with_extra_param(test_case_name: str) -> None: """Checks that the storage config fails to validate if extra params are present in certain places. 
""" - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE) - check_test_case_config_with_extra_param(TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED) - - -if __name__ == '__main__': - pytest.main([__file__, '-n0'],) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE + ) + check_test_case_config_with_extra_param( + TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + ) + + +if __name__ == "__main__": + pytest.main( + [__file__, "-n0"], + ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py index 762314961e..f0694fc50f 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-params/test_tunable_params_schemas.py @@ -23,6 +23,7 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_params_configs_against_schema(test_case_name: str) -> None: """Checks that the tunable params config validates against the schema.""" diff --git a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py index 0426373a90..9b24a39d75 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/tunable-values/test_tunable_values_schemas.py @@ -23,12 +23,14 @@ # Now we actually perform all of those validation tests. + @pytest.mark.parametrize("test_case_name", sorted(TEST_CASES.by_path)) def test_tunable_values_configs_against_schema(test_case_name: str) -> None: """Checks that the tunable values config validates against the schema.""" check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.TUNABLE_VALUES) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. - # The trouble is that tunable-values, cli, globals all look like flat dicts with minor constraints on them, - # so adding/removing params doesn't invalidate it against all of the config types. + # The trouble is that tunable-values, cli, globals all look like flat dicts + # with minor constraints on them, so adding/removing params doesn't + # invalidate it against all of the config types. 
check_test_case_against_schema(TEST_CASES.by_path[test_case_name], ConfigSchema.UNIFIED) diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index b5ac6380ed..5e9cb8ed13 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -23,19 +23,27 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: """If necessary, filter out json files that aren't for the module we're testing.""" + def predicate(config_path: str) -> bool: - arm_template = config_path.find("services/remote/azure/arm-templates/") >= 0 and config_path.endswith(".jsonc") + arm_template = config_path.find( + "services/remote/azure/arm-templates/" + ) >= 0 and config_path.endswith(".jsonc") setup_rg_scripts = config_path.find("azure/scripts/setup-rg") >= 0 return not (arm_template or setup_rg_scripts) + return [config_path for config_path in configs_to_filter if predicate(config_path)] -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_service_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_service_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) # Make an instance of the class based on the config. 
diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index ff2c8c6e5b..480b17425d 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -27,12 +27,16 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: return configs_to_filter -configs = locate_config_examples(ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs) +configs = locate_config_examples( + ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs +) assert configs @pytest.mark.parametrize("config_path", configs) -def test_load_storage_config_examples(config_loader_service: ConfigPersistenceService, config_path: str) -> None: +def test_load_storage_config_examples( + config_loader_service: ConfigPersistenceService, config_path: str +) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.STORAGE) assert isinstance(config, dict) diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index 2fc5268c26..09b242ac12 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -38,7 +38,7 @@ def mock_env(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score"], }, - tunables=tunable_groups + tunables=tunable_groups, ) @@ -53,7 +53,7 @@ def mock_env_no_noise(tunable_groups: TunableGroups) -> MockEnv: "mock_env_range": [60, 120], "mock_env_metrics": ["score", "other_score"], }, - tunables=tunable_groups + tunables=tunable_groups, ) @@ -97,7 +97,9 @@ def docker_compose_project_name(short_testrun_uid: str) -> str: @pytest.fixture(scope="session") -def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterProcessReaderWriterLock: +def docker_services_lock( + shared_temp_dir: str, short_testrun_uid: str +) -> InterProcessReaderWriterLock: """ Gets a pytest session lock for xdist workers to mark when they're using the docker services. @@ -107,7 +109,9 @@ def docker_services_lock(shared_temp_dir: str, short_testrun_uid: str) -> InterP A lock to ensure that setup/teardown operations don't happen while a worker is using the docker services. """ - return InterProcessReaderWriterLock(f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock") + return InterProcessReaderWriterLock( + f"{shared_temp_dir}/pytest_docker_services-{short_testrun_uid}.lock" + ) @pytest.fixture(scope="session") @@ -120,7 +124,9 @@ def docker_setup_teardown_lock(shared_temp_dir: str, short_testrun_uid: str) -> ------ A lock to ensure that only one worker is doing setup/teardown at a time. 
""" - return InterProcessLock(f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock") + return InterProcessLock( + f"{shared_temp_dir}/pytest_docker_services-setup-teardown-{short_testrun_uid}.lock" + ) @pytest.fixture(scope="session") diff --git a/mlos_bench/mlos_bench/tests/environments/__init__.py b/mlos_bench/mlos_bench/tests/environments/__init__.py index 667a31d69d..01155de0b2 100644 --- a/mlos_bench/mlos_bench/tests/environments/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/__init__.py @@ -14,11 +14,13 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def check_env_success(env: Environment, - tunable_groups: TunableGroups, - expected_results: Dict[str, TunableValue], - expected_telemetry: List[Tuple[datetime, str, Any]], - global_config: Optional[dict] = None) -> None: +def check_env_success( + env: Environment, + tunable_groups: TunableGroups, + expected_results: Dict[str, TunableValue], + expected_telemetry: List[Tuple[datetime, str, Any]], + global_config: Optional[dict] = None, +) -> None: """ Set up an environment and run a test experiment there. @@ -48,7 +50,7 @@ def check_env_success(env: Environment, assert telemetry == pytest.approx(expected_telemetry, nan_ok=True) env_context.teardown() - assert not env_context._is_ready # pylint: disable=protected-access + assert not env_context._is_ready # pylint: disable=protected-access def check_env_fail_telemetry(env: Environment, tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index 52bea41524..e7e17e6df7 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -24,9 +24,13 @@ def test_expand_groups() -> None: """Check the dollar variable expansion for tunable groups.""" - assert Environment._expand_groups( - ["begin", "$list", "$empty", "$str", "end"], - _GROUPS) == ["begin", "c", "d", "efg", "end"] + assert Environment._expand_groups(["begin", "$list", "$empty", "$str", "end"], _GROUPS) == [ + "begin", + "c", + "d", + "efg", + "end", + ] def test_expand_groups_empty_input() -> None: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py index c6c6fff78f..0d81ec7847 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py @@ -36,26 +36,26 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "name": "Env 3 :: tmp_other_3", "class": "mlos_bench.environments.mock_env.MockEnv", "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"], - } + }, ] }, tunables=tunable_groups, service=LocalExecService( - config={ - "temp_dir": "_test_tmp_global" - }, - parent=ConfigPersistenceService({ - "config_path": [ - path_join(os.path.dirname(__file__), "../config", abs_path=True), - ] - }) - ) + config={"temp_dir": "_test_tmp_global"}, + parent=ConfigPersistenceService( + { + "config_path": [ + path_join(os.path.dirname(__file__), "../config", abs_path=True), + ] + } + ), + ), ) def test_composite_services(composite_env: CompositeEnv) -> None: """Check that each environment gets its own instance of the services.""" - for (i, path) in ((0, "_test_tmp_global"), (1, "_test_tmp_other_2"), (2, "_test_tmp_other_3")): + for i, path in ((0, "_test_tmp_global"), (1, 
"_test_tmp_other_2"), (2, "_test_tmp_other_3")): service = composite_env.children[i]._service # pylint: disable=protected-access assert service is not None and hasattr(service, "temp_dir_context") with service.temp_dir_context() as temp_dir: diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index 0f2669e85a..77a6bf5ad4 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -24,7 +24,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", "someConst": "root", - "global_param": "default" + "global_param": "default", }, "children": [ { @@ -39,7 +39,7 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "someConst", "global_param"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, { "name": "Mock Server Environment 2", @@ -49,12 +49,12 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vmName": "$vm_server_name", "EnvId": 2, - "global_param": "local" + "global_param": "local", }, "required_args": ["vmName"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, { "name": "Mock Control Environment 3", @@ -68,15 +68,13 @@ def composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "vm_server_name", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } - } - ] + }, + }, + ], }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={ - "global_param": "global_value" - } + global_config={"global_param": "global_value"}, ) @@ -88,59 +86,61 @@ def test_composite_env_params(composite_env: CompositeEnv) -> None: NOTE: The current logic is that variables flow down via required_args and const_args, parent """ assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value" # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the 
child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", - "vm_server_name": "Mock Server VM" + "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config } def test_composite_env_setup(composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: """Check that the child environments update their tunable parameters.""" - tunable_groups.assign({ - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + } + ) with composite_env as env_context: assert env_context.setup(tunable_groups) assert composite_env.children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args - "global_param": "global_value" # pulled in from the global_config + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[1].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert composite_env.children[2].parameters == { - "vmName": "Mock Control VM", # const_args from the parent - "EnvId": 3, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Control VM", # const_args from the parent + "EnvId": 3, # const_args from the child + "idle": "mwait", # tunable_params from the parent "vm_client_name": "Mock Client VM", "vm_server_name": "Mock Server VM", # "global_param": "global_value" # not required, so not picked from the global_config @@ -157,7 +157,7 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "const_args": { "vm_server_name": "Mock Server VM", "vm_client_name": "Mock Client VM", - "someConst": "root" + "someConst": "root", }, "children": [ { @@ -185,11 +185,11 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "EnvId", "someConst", "vm_server_name", - "global_param" + "global_param", ], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, # ... ], @@ -214,20 +214,17 @@ def nested_composite_env(tunable_groups: TunableGroups) -> CompositeEnv: "required_args": ["vmName", "EnvId", "vm_client_name"], "mock_env_range": [60, 120], "mock_env_metrics": ["score"], - } + }, }, # ... 
], }, }, - - ] + ], }, tunables=tunable_groups, service=ConfigPersistenceService({}), - global_config={ - "global_param": "global_value" - } + global_config={"global_param": "global_value"}, ) @@ -240,50 +237,54 @@ def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None """ assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent thru variable substitution - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B4ms", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent thru variable substitution + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B4ms", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "halt", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": "halt", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", # "global_param": "global_value" # not required, so not picked from the global_config } -def test_nested_composite_env_setup(nested_composite_env: CompositeEnv, tunable_groups: TunableGroups) -> None: +def test_nested_composite_env_setup( + nested_composite_env: CompositeEnv, tunable_groups: TunableGroups +) -> None: """Check that the child environments update their tunable parameters.""" - tunable_groups.assign({ - "vmSize": "Standard_B2s", - "idle": "mwait", - "kernel_sched_migration_cost_ns": 100000, - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 100000, + } + ) with nested_composite_env as env_context: assert env_context.setup(tunable_groups) assert isinstance(nested_composite_env.children[0], CompositeEnv) assert nested_composite_env.children[0].children[0].parameters == { - "vmName": "Mock Client VM", # const_args from the parent - "EnvId": 1, # const_args from the child - "vmSize": "Standard_B2s", # tunable_params from the parent - "someConst": "root", # pulled in from parent via required_args + "vmName": "Mock Client VM", # const_args from the parent + "EnvId": 1, # const_args from the child + "vmSize": "Standard_B2s", # tunable_params from the parent + "someConst": "root", # pulled in from parent via required_args "vm_server_name": "Mock Server VM", - "global_param": "global_value" # pulled in from the global_config + "global_param": "global_value", # pulled in from the global_config } assert isinstance(nested_composite_env.children[1], CompositeEnv) assert nested_composite_env.children[1].children[0].parameters == { - "vmName": "Mock Server VM", # const_args from the parent - "EnvId": 2, # const_args from the child - "idle": "mwait", # tunable_params from the parent + "vmName": "Mock Server VM", # const_args from the parent + "EnvId": 2, # const_args from the child + "idle": 
"mwait", # tunable_params from the parent # "someConst": "root" # not required, so not passed from the parent "vm_client_name": "Mock Client VM", } diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py index 0450dfa44d..a3df4cb558 100644 --- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py +++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py @@ -12,9 +12,7 @@ def test_one_group(tunable_groups: TunableGroups) -> None: """Make sure only one tunable group is available to the environment.""" env = MockEnv( - name="Test Env", - config={"tunable_params": ["provision"]}, - tunables=tunable_groups + name="Test Env", config={"tunable_params": ["provision"]}, tunables=tunable_groups ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -26,7 +24,7 @@ def test_two_groups(tunable_groups: TunableGroups) -> None: env = MockEnv( name="Test Env", config={"tunable_params": ["provision", "kernel"]}, - tunables=tunable_groups + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", @@ -48,7 +46,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups + tunables=tunable_groups, ) expected_params = { "vmSize": "Standard_B4ms", @@ -71,11 +69,7 @@ def test_two_groups_setup(tunable_groups: TunableGroups) -> None: def test_zero_groups_implicit(tunable_groups: TunableGroups) -> None: """Make sure that no tunable groups are available to the environment by default.""" - env = MockEnv( - name="Test Env", - config={}, - tunables=tunable_groups - ) + env = MockEnv(name="Test Env", config={}, tunables=tunable_groups) assert env.tunable_params.get_param_values() == {} @@ -83,11 +77,7 @@ def test_zero_groups_explicit(tunable_groups: TunableGroups) -> None: """Make sure that no tunable groups are available to the environment when explicitly specifying an empty list of tunable_params. 
""" - env = MockEnv( - name="Test Env", - config={"tunable_params": []}, - tunables=tunable_groups - ) + env = MockEnv(name="Test Env", config={"tunable_params": []}, tunables=tunable_groups) assert env.tunable_params.get_param_values() == {} @@ -103,7 +93,7 @@ def test_zero_groups_implicit_setup(tunable_groups: TunableGroups) -> None: "const_param2": "foo", }, }, - tunables=tunable_groups + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == {} @@ -125,9 +115,7 @@ def test_loader_level_include() -> None: env_json = { "class": "mlos_bench.environments.mock_env.MockEnv", "name": "Test Env", - "include_tunables": [ - "environments/os/linux/boot/linux-boot-tunables.jsonc" - ], + "include_tunables": ["environments/os/linux/boot/linux-boot-tunables.jsonc"], "config": { "tunable_params": ["linux-kernel-boot"], "const_args": { @@ -136,12 +124,14 @@ def test_loader_level_include() -> None: }, }, } - loader = ConfigPersistenceService({ - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - }) + loader = ConfigPersistenceService( + { + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + } + ) env = loader.build_environment(config=env_json, tunables=TunableGroups()) expected_params = { "align_va_addr": "on", diff --git a/mlos_bench/mlos_bench/tests/environments/local/__init__.py b/mlos_bench/mlos_bench/tests/environments/local/__init__.py index 4ef31ec299..d0a954f6fc 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/environments/local/__init__.py @@ -33,14 +33,20 @@ def create_local_env(tunable_groups: TunableGroups, config: Dict[str, Any]) -> L env : LocalEnv A new instance of the local environment. """ - return LocalEnv(name="TestLocalEnv", config=config, tunables=tunable_groups, - service=LocalExecService(parent=ConfigPersistenceService())) + return LocalEnv( + name="TestLocalEnv", + config=config, + tunables=tunable_groups, + service=LocalExecService(parent=ConfigPersistenceService()), + ) -def create_composite_local_env(tunable_groups: TunableGroups, - global_config: Dict[str, Any], - params: Dict[str, Any], - local_configs: List[Dict[str, Any]]) -> CompositeEnv: +def create_composite_local_env( + tunable_groups: TunableGroups, + global_config: Dict[str, Any], + params: Dict[str, Any], + local_configs: List[Dict[str, Any]], +) -> CompositeEnv: """ Create a CompositeEnv with several LocalEnv instances. 
@@ -71,7 +77,7 @@ def create_composite_local_env(tunable_groups: TunableGroups, "config": config, } for (i, config) in enumerate(local_configs) - ] + ], }, tunables=tunable_groups, global_config=global_config, diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index f8a8271a7f..1f3cf66110 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -42,7 +42,7 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - (var_prefix, var_suffix) = ("%", "%") if sys.platform == 'win32' else ("$", "") + (var_prefix, var_suffix) = ("%", "%") if sys.platform == "win32" else ("$", "") env = create_composite_local_env( tunable_groups=tunable_groups, @@ -66,8 +66,8 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo "required_args": ["errors", "reads"], "shell_env_params": [ "latency", # const_args overridden by the composite env - "errors", # Comes from the parent const_args - "reads" # const_args overridden by the global config + "errors", # Comes from the parent const_args + "reads", # const_args overridden by the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -89,9 +89,9 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo }, "required_args": ["writes"], "shell_env_params": [ - "throughput", # const_args overridden by the composite env - "score", # Comes from the local const_args - "writes" # Comes straight from the global config + "throughput", # const_args overridden by the composite env + "score", # Comes from the local const_args + "writes", # Comes straight from the global config ], "run": [ "echo 'metric,value' > output.csv", @@ -105,12 +105,13 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo ], "read_results_file": "output.csv", "read_telemetry_file": "telemetry.csv", - } - ] + }, + ], ) check_env_success( - env, tunable_groups, + env, + tunable_groups, expected_results={ "latency": 4.2, "throughput": 768.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py index fcdd9b1eab..684e7e13f6 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_stdout_test.py @@ -13,19 +13,23 @@ def test_local_env_stdout(tunable_groups: TunableGroups) -> None: """Print benchmark results to stdout and capture them in the LocalEnv.""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", - ], - "results_stdout_pattern": r"(\w+),([0-9.]+)", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", + ], + "results_stdout_pattern": r"(\w+),([0-9.]+)", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -39,19 +43,23 @@ def test_local_env_stdout(tunable_groups: 
TunableGroups) -> None: def test_local_env_stdout_anchored(tunable_groups: TunableGroups) -> None: """Print benchmark results to stdout and capture them in the LocalEnv.""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'Benchmark results:'", # This line should be ignored - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern - ], - "results_stdout_pattern": r"^(\w+),([0-9.]+)$", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'Benchmark results:'", # This line should be ignored + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'a,0,b,1'", # This line should be ignored in the case of anchored pattern + ], + "results_stdout_pattern": r"^(\w+),([0-9.]+)$", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, @@ -66,24 +74,28 @@ def test_local_env_file_stdout(tunable_groups: TunableGroups) -> None: """Print benchmark results to *BOTH* stdout and a file and extract the results from both. """ - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'latency,111'", - "echo 'throughput,222'", - "echo 'score,0.999'", - "echo 'stdout-msg,string'", - "echo '-------------------'", # Should be ignored - "echo 'metric,value' > output.csv", - "echo 'extra1,333' >> output.csv", - "echo 'extra2,444' >> output.csv", - "echo 'file-msg,string' >> output.csv", - ], - "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'latency,111'", + "echo 'throughput,222'", + "echo 'score,0.999'", + "echo 'stdout-msg,string'", + "echo '-------------------'", # Should be ignored + "echo 'metric,value' > output.csv", + "echo 'extra1,333' >> output.csv", + "echo 'extra2,444' >> output.csv", + "echo 'file-msg,string' >> output.csv", + ], + "results_stdout_pattern": r"([a-zA-Z0-9_-]+),([a-z0-9.]+)", + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 111.0, "throughput": 222.0, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index 6fb2718706..7f3b070109 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -33,25 +33,29 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,4.1' >> output.csv", - "echo 'throughput,512' >> output.csv", - "echo 'score,0.95' >> output.csv", - "echo '-------------------'", # This output does not go anywhere - "echo 'timestamp,metric,value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_results_file": "output.csv", - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'metric,value' > 
output.csv", + "echo 'latency,4.1' >> output.csv", + "echo 'throughput,512' >> output.csv", + "echo 'score,0.95' >> output.csv", + "echo '-------------------'", # This output does not go anywhere + "echo 'timestamp,metric,value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_results_file": "output.csv", + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 4.1, "throughput": 512.0, @@ -68,7 +72,9 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_local_env_telemetry_no_header( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """Read the telemetry data with no header.""" ts1 = datetime.now(zone_info) ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second @@ -78,18 +84,22 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - f"echo {time_str1},cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + f"echo {time_str1},cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={}, expected_telemetry=[ (ts1.astimezone(UTC), "cpu_load", 0.65), @@ -100,9 +110,16 @@ def test_local_env_telemetry_no_header(tunable_groups: TunableGroups, zone_info: ) -@pytest.mark.filterwarnings("ignore:.*(Could not infer format, so each element will be parsed individually, falling back to `dateutil`).*:UserWarning::0") # pylint: disable=line-too-long # noqa +@pytest.mark.filterwarnings( + ( + "ignore:.*(Could not infer format, so each element will be parsed individually, " + "falling back to `dateutil`).*:UserWarning::0" + ) +) # pylint: disable=line-too-long # noqa @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_info: Optional[tzinfo]) -> None: +def test_local_env_telemetry_wrong_header( + tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """Read the telemetry data with incorrect header.""" ts1 = datetime.now(zone_info) ts1 -= timedelta(microseconds=ts1.microsecond) # Round to a second @@ -112,17 +129,20 @@ def test_local_env_telemetry_wrong_header(tunable_groups: TunableGroups, zone_in time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: the data is correct, but the header has unexpected 
column names - "echo 'ts,metric_name,metric_value' > telemetry.csv", - f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", - f"echo {time_str1},mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: the data is correct, but the header has unexpected column names + "echo 'ts,metric_name,metric_value' > telemetry.csv", + f"echo {time_str1},cpu_load,0.65 >> telemetry.csv", + f"echo {time_str1},mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) @@ -138,31 +158,37 @@ def test_local_env_telemetry_invalid(tunable_groups: TunableGroups) -> None: time_str1 = ts1.strftime(format_str) time_str2 = ts2.strftime(format_str) - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: too many columns - f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", - f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", - f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", - f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: too many columns + f"echo {time_str1},EXTRA,cpu_load,0.65 > telemetry.csv", + f"echo {time_str1},EXTRA,mem_usage,10240 >> telemetry.csv", + f"echo {time_str2},EXTRA,cpu_load,0.8 >> telemetry.csv", + f"echo {time_str2},EXTRA,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) def test_local_env_telemetry_invalid_ts(tunable_groups: TunableGroups) -> None: """Fail when the telemetry data has wrong format.""" - local_env = create_local_env(tunable_groups, { - "run": [ - # Error: field 1 must be a timestamp - "echo 1,cpu_load,0.65 > telemetry.csv", - "echo 2,mem_usage,10240 >> telemetry.csv", - "echo 3,cpu_load,0.8 >> telemetry.csv", - "echo 4,mem_usage,20480 >> telemetry.csv", - ], - "read_telemetry_file": "telemetry.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # Error: field 1 must be a timestamp + "echo 1,cpu_load,0.65 > telemetry.csv", + "echo 2,mem_usage,10240 >> telemetry.csv", + "echo 3,cpu_load,0.8 >> telemetry.csv", + "echo 4,mem_usage,20480 >> telemetry.csv", + ], + "read_telemetry_file": "telemetry.csv", + }, + ) check_env_fail_telemetry(local_env, tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py index 5ba125c028..25eea76b17 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_test.py @@ -12,18 +12,22 @@ def test_local_env(tunable_groups: TunableGroups) -> None: """Produce benchmark and telemetry data in a local script and read it.""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'metric,value' > output.csv", - "echo 'latency,10' >> output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'metric,value' > output.csv", + "echo 
'latency,10' >> output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 10.0, "throughput": 66.0, @@ -37,9 +41,7 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: """Basic check that context support for Service mixins are handled when environment contexts are entered. """ - local_env = create_local_env(tunable_groups, { - "run": ["echo NA"] - }) + local_env = create_local_env(tunable_groups, {"run": ["echo NA"]}) # pylint: disable=protected-access assert local_env._service assert not local_env._service._in_context @@ -47,25 +49,28 @@ def test_local_env_service_context(tunable_groups: TunableGroups) -> None: with local_env as env_context: assert env_context._in_context assert local_env._service._in_context - assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) + assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive) assert all(svc._in_context for svc in local_env._service._service_contexts) assert all(svc._in_context for svc in local_env._service._services) - assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) + assert not local_env._service._in_context # type: ignore[unreachable] # (false positive) assert not local_env._service._service_contexts assert not any(svc._in_context for svc in local_env._service._services) def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: """Fail if the results are not in the expected format.""" - local_env = create_local_env(tunable_groups, { - "run": [ - # No header - "echo 'latency,10' > output.csv", - "echo 'throughput,66' >> output.csv", - "echo 'score,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + # No header + "echo 'latency,10' > output.csv", + "echo 'throughput,66' >> output.csv", + "echo 'score,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) with local_env as env_context: assert env_context.setup(tunable_groups) @@ -75,16 +80,20 @@ def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None: def test_local_env_wide(tunable_groups: TunableGroups) -> None: """Produce benchmark data in wide format and read it.""" - local_env = create_local_env(tunable_groups, { - "run": [ - "echo 'latency,throughput,score' > output.csv", - "echo '10,66,0.9' >> output.csv", - ], - "read_results_file": "output.csv", - }) + local_env = create_local_env( + tunable_groups, + { + "run": [ + "echo 'latency,throughput,score' > output.csv", + "echo '10,66,0.9' >> output.csv", + ], + "read_results_file": "output.csv", + }, + ) check_env_success( - local_env, tunable_groups, + local_env, + tunable_groups, expected_results={ "latency": 10, "throughput": 66, diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py index 16fd53959c..ef90155f0f 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_vars_test.py @@ -14,42 +14,45 @@ def _run_local_env(tunable_groups: TunableGroups, shell_subcmd: str, expected: dict) -> None: """Check that LocalEnv can set shell environment variables.""" - local_env = 
create_local_env(tunable_groups, { - "const_args": { - "const_arg": 111, # Passed into "shell_env_params" - "other_arg": 222, # NOT passed into "shell_env_params" + local_env = create_local_env( + tunable_groups, + { + "const_args": { + "const_arg": 111, # Passed into "shell_env_params" + "other_arg": 222, # NOT passed into "shell_env_params" + }, + "tunable_params": ["kernel"], + "shell_env_params": [ + "const_arg", # From "const_arg" + "kernel_sched_latency_ns", # From "tunable_params" + ], + "run": [ + "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", + f"echo {shell_subcmd} >> output.csv", + ], + "read_results_file": "output.csv", }, - "tunable_params": ["kernel"], - "shell_env_params": [ - "const_arg", # From "const_arg" - "kernel_sched_latency_ns", # From "tunable_params" - ], - "run": [ - "echo const_arg,other_arg,unknown_arg,kernel_sched_latency_ns > output.csv", - f"echo {shell_subcmd} >> output.csv", - ], - "read_results_file": "output.csv", - }) + ) check_env_success(local_env, tunable_groups, expected, []) -@pytest.mark.skipif(sys.platform == 'win32', reason="sh-like shell only") +@pytest.mark.skipif(sys.platform == "win32", reason="sh-like shell only") def test_local_env_vars_shell(tunable_groups: TunableGroups) -> None: """Check that LocalEnv can set shell environment variables in sh-like shell.""" _run_local_env( tunable_groups, shell_subcmd="$const_arg,$other_arg,$unknown_arg,$kernel_sched_latency_ns", expected={ - "const_arg": 111, # From "const_args" - "other_arg": float("NaN"), # Not included in "shell_env_params" - "unknown_arg": float("NaN"), # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - } + "const_arg": 111, # From "const_args" + "other_arg": float("NaN"), # Not included in "shell_env_params" + "unknown_arg": float("NaN"), # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + }, ) -@pytest.mark.skipif(sys.platform != 'win32', reason="Windows only") +@pytest.mark.skipif(sys.platform != "win32", reason="Windows only") def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: """Check that LocalEnv can set shell environment variables on Windows / cmd shell. 
@@ -58,9 +61,9 @@ def test_local_env_vars_windows(tunable_groups: TunableGroups) -> None: tunable_groups, shell_subcmd=r"%const_arg%,%other_arg%,%unknown_arg%,%kernel_sched_latency_ns%", expected={ - "const_arg": 111, # From "const_args" - "other_arg": r"%other_arg%", # Not included in "shell_env_params" - "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable - "kernel_sched_latency_ns": 2000000, # From "tunable_params" - } + "const_arg": 111, # From "const_args" + "other_arg": r"%other_arg%", # Not included in "shell_env_params" + "unknown_arg": r"%unknown_arg%", # Unknown/undefined variable + "kernel_sched_latency_ns": 2000000, # From "tunable_params" + }, ) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index 8f703c1d01..08ce0790bc 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -21,25 +21,26 @@ def mock_fileshare_service() -> MockFileShareService: """Create a new mock FileShareService instance.""" return MockFileShareService( config={"fileShareName": "MOCK_FILESHARE"}, - parent=LocalExecService(parent=ConfigPersistenceService()) + parent=LocalExecService(parent=ConfigPersistenceService()), ) @pytest.fixture -def local_fileshare_env(tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService) -> LocalFileShareEnv: +def local_fileshare_env( + tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService +) -> LocalFileShareEnv: """Create a LocalFileShareEnv instance.""" env = LocalFileShareEnv( name="TestLocalFileShareEnv", config={ "const_args": { "experiment_id": "EXP_ID", # Passed into "shell_env_params" - "trial_id": 222, # NOT passed into "shell_env_params" + "trial_id": 222, # NOT passed into "shell_env_params" }, "tunable_params": ["boot"], "shell_env_params": [ - "trial_id", # From "const_arg" - "idle", # From "tunable_params", == "halt" + "trial_id", # From "const_arg" + "idle", # From "tunable_params", == "halt" ], "upload": [ { @@ -51,9 +52,7 @@ def local_fileshare_env(tunable_groups: TunableGroups, "to": "$experiment_id/$trial_id/input/data_$idle.csv", }, ], - "run": [ - "echo No-op run" - ], + "run": ["echo No-op run"], "download": [ { "from": "$experiment_id/$trial_id/$idle/data.csv", @@ -67,9 +66,11 @@ def local_fileshare_env(tunable_groups: TunableGroups, return env -def test_local_fileshare_env(tunable_groups: TunableGroups, - mock_fileshare_service: MockFileShareService, - local_fileshare_env: LocalFileShareEnv) -> None: +def test_local_fileshare_env( + tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService, + local_fileshare_env: LocalFileShareEnv, +) -> None: """Test that the LocalFileShareEnv correctly expands the `$VAR` variables in the upload and download sections of the config. 
""" diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py index b055f4f6aa..b29f1098d7 100644 --- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py @@ -36,20 +36,22 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr assert data["score"] == pytest.approx(75.0, 0.01) -@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000 - }, 66.4), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000 - }, 74.06), -]) -def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float) -> None: +@pytest.mark.parametrize( + ("tunable_values", "expected_score"), + [ + ( + {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000}, + 66.4, + ), + ( + {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000}, + 74.06, + ), + ], +) +def test_mock_env_assign( + mock_env: MockEnv, tunable_groups: TunableGroups, tunable_values: dict, expected_score: float +) -> None: """Check the benchmark values of the mock environment after the assignment.""" with mock_env as env_context: tunable_groups.assign(tunable_values) @@ -60,21 +62,25 @@ def test_mock_env_assign(mock_env: MockEnv, tunable_groups: TunableGroups, assert data["score"] == pytest.approx(expected_score, 0.01) -@pytest.mark.parametrize(('tunable_values', 'expected_score'), [ - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 250000 - }, 67.5), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 40000 - }, 75.1), -]) -def test_mock_env_no_noise_assign(mock_env_no_noise: MockEnv, - tunable_groups: TunableGroups, - tunable_values: dict, expected_score: float) -> None: +@pytest.mark.parametrize( + ("tunable_values", "expected_score"), + [ + ( + {"vmSize": "Standard_B2ms", "idle": "halt", "kernel_sched_migration_cost_ns": 250000}, + 67.5, + ), + ( + {"vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 40000}, + 75.1, + ), + ], +) +def test_mock_env_no_noise_assign( + mock_env_no_noise: MockEnv, + tunable_groups: TunableGroups, + tunable_values: dict, + expected_score: float, +) -> None: """Check the benchmark values of the noiseless mock environment after the assignment. 
""" diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index 6fb9dba8c4..ecefc05cdd 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -34,25 +34,31 @@ def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None: "ssh_priv_key_path": ssh_test_server.id_rsa_path, } - service = ConfigPersistenceService(config={"config_path": [str(files("mlos_bench.tests.config"))]}) + service = ConfigPersistenceService( + config={"config_path": [str(files("mlos_bench.tests.config"))]} + ) config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc") - env = service.load_environment(config_path, TunableGroups(), global_config=global_config, service=service) + env = service.load_environment( + config_path, TunableGroups(), global_config=global_config, service=service + ) check_env_success( - env, env.tunable_params, + env, + env.tunable_params, expected_results={ "hostname": ssh_test_server.service_name, "username": ssh_test_server.username, "score": 0.9, - "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" + "ssh_priv_key_path": np.nan, # empty strings are returned as "not a number" "test_param": "unset", "FOO": "unset", "ssh_username": "unset", }, expected_telemetry=[], ) - assert not os.path.exists(os.path.join(os.getcwd(), "output-downloaded.csv")), \ - "output-downloaded.csv should have been cleaned up by temp_dir context" + assert not os.path.exists( + os.path.join(os.getcwd(), "output-downloaded.csv") + ), "output-downloaded.csv should have been cleaned up by temp_dir context" if __name__ == "__main__": diff --git a/mlos_bench/mlos_bench/tests/event_loop_context_test.py b/mlos_bench/mlos_bench/tests/event_loop_context_test.py index eee8f53304..eb92c4c132 100644 --- a/mlos_bench/mlos_bench/tests/event_loop_context_test.py +++ b/mlos_bench/mlos_bench/tests/event_loop_context_test.py @@ -39,16 +39,21 @@ def __enter__(self) -> None: self.EVENT_LOOP_CONTEXT.enter() self._in_context = True - def __exit__(self, ex_type: Optional[Type[BaseException]], - ex_val: Optional[BaseException], - ex_tb: Optional[TracebackType]) -> Literal[False]: + def __exit__( + self, + ex_type: Optional[Type[BaseException]], + ex_val: Optional[BaseException], + ex_tb: Optional[TracebackType], + ) -> Literal[False]: assert self._in_context self.EVENT_LOOP_CONTEXT.exit() self._in_context = False return False -@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") +@pytest.mark.filterwarnings( + "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" +) def test_event_loop_context() -> None: """Test event loop context background thread setup/cleanup handling.""" # pylint: disable=protected-access,too-many-statements @@ -67,7 +72,9 @@ def test_event_loop_context() -> None: # After we enter the instance context, we should have a background thread. with event_loop_caller_instance_1: assert event_loop_caller_instance_1._in_context - assert isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread, Thread) # type: ignore[unreachable] + assert ( # type: ignore[unreachable] + isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread, Thread) + ) # Give the thread a chance to start. # Mostly important on the underpowered Windows CI machines. 
time.sleep(0.25) @@ -86,12 +93,16 @@ def test_event_loop_context() -> None: assert event_loop_caller_instance_1._in_context assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 2 # We should only get one thread for all instances. - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread \ - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread \ + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop_thread is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop_thread - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop \ - is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop \ + ) + assert ( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop + is event_loop_caller_instance_1.EVENT_LOOP_CONTEXT._event_loop is event_loop_caller_instance_2.EVENT_LOOP_CONTEXT._event_loop + ) assert not event_loop_caller_instance_2._in_context @@ -103,30 +114,40 @@ def test_event_loop_context() -> None: assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == 'foo' + assert future.result(timeout=0.2) == "foo" assert 0.1 <= time.time() - start <= 0.2 # Once we exit the last context, the background thread should be stopped # and unusable for running co-routines. - assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None # type: ignore[unreachable] # (false positives) + assert ( # type: ignore[unreachable] # (false positives) + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread is None + ) assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop_thread_refcnt == 0 assert EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop is event_loop is not None assert not EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop.is_running() # Check that the event loop has no more tasks. - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_ready') + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_ready") # Windows ProactorEventLoopPolicy adds a dummy task. 
- if sys.platform == 'win32' and isinstance(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop): + if sys.platform == "win32" and isinstance( + EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, asyncio.ProactorEventLoop + ): assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 1 else: assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._ready) == 0 - assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, '_scheduled') + assert hasattr(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop, "_scheduled") assert len(EventLoopContextCaller.EVENT_LOOP_CONTEXT._event_loop._scheduled) == 0 - with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + with pytest.raises( + AssertionError + ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) raise ValueError(f"Future should not have been available to wait on {future.result()}") # Test that when re-entering the context we have the same event loop. @@ -137,12 +158,14 @@ def test_event_loop_context() -> None: # Test running again. start = time.time() - future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine(asyncio.sleep(0.1, result='foo')) + future = event_loop_caller_instance_1.EVENT_LOOP_CONTEXT.run_coroutine( + asyncio.sleep(0.1, result="foo") + ) assert 0.0 <= time.time() - start < 0.1 - assert future.result(timeout=0.2) == 'foo' + assert future.result(timeout=0.2) == "foo" assert 0.1 <= time.time() - start <= 0.2 -if __name__ == '__main__': +if __name__ == "__main__": # For debugging in Windows which has issues with pytest detection in vscode. pytest.main(["-n1", "--dist=no", "-k", "test_event_loop_context"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py index 04750b7c2a..6fe340c9eb 100644 --- a/mlos_bench/mlos_bench/tests/launcher_in_process_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_in_process_test.py @@ -12,19 +12,33 @@ @pytest.mark.parametrize( - ("argv", "expected_score"), [ - ([ - "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", - "--trial_config_repeat_count", "5", - "--mock_env_seed", "-1", # Deterministic Mock Environment. - ], 67.40329), - ([ - "--config", "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", - "--trial_config_repeat_count", "3", - "--max_suggestions", "3", - "--mock_env_seed", "42", # Noisy Mock Environment. - ], 64.53897), - ] + ("argv", "expected_score"), + [ + ( + [ + "--config", + "mlos_bench/mlos_bench/tests/config/cli/mock-bench.jsonc", + "--trial_config_repeat_count", + "5", + "--mock_env_seed", + "-1", # Deterministic Mock Environment. + ], + 67.40329, + ), + ( + [ + "--config", + "mlos_bench/mlos_bench/tests/config/cli/mock-opt.jsonc", + "--trial_config_repeat_count", + "3", + "--max_suggestions", + "3", + "--mock_env_seed", + "42", # Noisy Mock Environment. 
+ ], + 64.53897, + ), + ], ) def test_main_bench(argv: List[str], expected_score: float) -> None: """Run mlos_bench optimization loop with given config and check the results.""" diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index 687436d316..f577f21526 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -48,8 +48,8 @@ def config_paths() -> List[str]: """ return [ path_join(os.getcwd(), abs_path=True), - str(files('mlos_bench.config')), - str(files('mlos_bench.tests.config')), + str(files("mlos_bench.config")), + str(files("mlos_bench.tests.config")), ] @@ -65,20 +65,23 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == 'win32': + if sys.platform == "win32": # Some env tweaks for platform compatibility. - environ['USER'] = environ['USERNAME'] + environ["USER"] = environ["USERNAME"] # This is part of the minimal required args by the Launcher. - env_conf_path = 'environments/mock/mock_env.jsonc' - cli_args = '--config-paths ' + ' '.join(config_paths) + \ - ' --service services/remote/mock/mock_auth_service.jsonc' + \ - ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ - ' --scheduler schedulers/sync_scheduler.jsonc' + \ - f' --environment {env_conf_path}' + \ - ' --globals globals/global_test_config.jsonc' + \ - ' --globals globals/global_test_extra_config.jsonc' \ - ' --test_global_value_2 from-args' + env_conf_path = "environments/mock/mock_env.jsonc" + cli_args = ( + "--config-paths " + + " ".join(config_paths) + + " --service services/remote/mock/mock_auth_service.jsonc" + + " --service services/remote/mock/mock_remote_exec_service.jsonc" + + " --scheduler schedulers/sync_scheduler.jsonc" + + f" --environment {env_conf_path}" + + " --globals globals/global_test_config.jsonc" + + " --globals globals/global_test_extra_config.jsonc" + " --test_global_value_2 from-args" + ) launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -86,27 +89,28 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsLocalExec) assert isinstance(launcher.service, SupportsRemoteExec) # Check that the first --globals file is loaded and $var expansion is handled. - assert launcher.global_config['experiment_id'] == 'MockExperiment' - assert launcher.global_config['testVmName'] == 'MockExperiment-vm' + assert launcher.global_config["experiment_id"] == "MockExperiment" + assert launcher.global_config["testVmName"] == "MockExperiment-vm" # Check that secondary expansion also works. - assert launcher.global_config['testVnetName'] == 'MockExperiment-vm-vnet' + assert launcher.global_config["testVnetName"] == "MockExperiment-vm-vnet" # Check that the second --globals file is loaded. - assert launcher.global_config['test_global_value'] == 'from-file' + assert launcher.global_config["test_global_value"] == "from-file" # Check overriding values in a file from the command line. - assert launcher.global_config['test_global_value_2'] == 'from-args' + assert launcher.global_config["test_global_value_2"] == "from-args" # Check that we can expand a $var in a config file that references an environment variable. 
- assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ - == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join( + os.getcwd(), "foo", abs_path=True + ) + assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert launcher.teardown # Check that the environment that got loaded looks to be of the right type. env_config = launcher.config_loader.load_config(env_conf_path, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config['class']) + assert check_class_name(launcher.environment, env_config["class"]) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, OneShotOptimizer) # Check that the optimizer got initialized with defaults. assert launcher.optimizer.tunable_params.is_defaults() - assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer + assert launcher.optimizer.max_iterations == 1 # value for OneShotOptimizer # Check that we pick up the right scheduler config: assert isinstance(launcher.scheduler, SyncScheduler) assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access @@ -122,23 +126,25 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # variable so we use a separate variable. # See global_test_config.jsonc for more details. environ["CUSTOM_PATH_FROM_ENV"] = os.getcwd() - if sys.platform == 'win32': + if sys.platform == "win32": # Some env tweaks for platform compatibility. - environ['USER'] = environ['USERNAME'] - - config_file = 'cli/test-cli-config.jsonc' - globals_file = 'globals/global_test_config.jsonc' - cli_args = ' '.join([f"--config-path {config_path}" for config_path in config_paths]) + \ - f' --config {config_file}' + \ - ' --service services/remote/mock/mock_auth_service.jsonc' + \ - ' --service services/remote/mock/mock_remote_exec_service.jsonc' + \ - f' --globals {globals_file}' + \ - ' --experiment_id MockeryExperiment' + \ - ' --no-teardown' + \ - ' --random-init' + \ - ' --random-seed 1234' + \ - ' --trial-config-repeat-count 5' + \ - ' --max_trials 200' + environ["USER"] = environ["USERNAME"] + + config_file = "cli/test-cli-config.jsonc" + globals_file = "globals/global_test_config.jsonc" + cli_args = ( + " ".join([f"--config-path {config_path}" for config_path in config_paths]) + + f" --config {config_file}" + + " --service services/remote/mock/mock_auth_service.jsonc" + + " --service services/remote/mock/mock_remote_exec_service.jsonc" + + f" --globals {globals_file}" + + " --experiment_id MockeryExperiment" + + " --no-teardown" + + " --random-init" + + " --random-seed 1234" + + " --trial-config-repeat-count 5" + + " --max_trials 200" + ) launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -148,35 +154,42 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert isinstance(launcher.service, SupportsRemoteExec) # Check that the --globals file is loaded and $var expansion is handled # using the value provided on the CLI. 
- assert launcher.global_config['experiment_id'] == 'MockeryExperiment' - assert launcher.global_config['testVmName'] == 'MockeryExperiment-vm' + assert launcher.global_config["experiment_id"] == "MockeryExperiment" + assert launcher.global_config["testVmName"] == "MockeryExperiment-vm" # Check that secondary expansion also works. - assert launcher.global_config['testVnetName'] == 'MockeryExperiment-vm-vnet' + assert launcher.global_config["testVnetName"] == "MockeryExperiment-vm-vnet" # Check that we can expand a $var in a config file that references an environment variable. - assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) \ - == path_join(os.getcwd(), "foo", abs_path=True) - assert launcher.global_config["varWithEnvVarRef"] == f'user:{getuser()}' + assert path_join(launcher.global_config["pathVarWithEnvVarRef"], abs_path=True) == path_join( + os.getcwd(), "foo", abs_path=True + ) + assert launcher.global_config["varWithEnvVarRef"] == f"user:{getuser()}" assert not launcher.teardown config = launcher.config_loader.load_config(config_file, ConfigSchema.CLI) - assert launcher.config_loader.config_paths == [path_join(path, abs_path=True) for path in config_paths + config['config_path']] + assert launcher.config_loader.config_paths == [ + path_join(path, abs_path=True) for path in config_paths + config["config_path"] + ] # Check that the environment that got loaded looks to be of the right type. - env_config_file = config['environment'] + env_config_file = config["environment"] env_config = launcher.config_loader.load_config(env_config_file, ConfigSchema.ENVIRONMENT) - assert check_class_name(launcher.environment, env_config['class']) + assert check_class_name(launcher.environment, env_config["class"]) # Check that the optimizer looks right. assert isinstance(launcher.optimizer, MlosCoreOptimizer) - opt_config_file = config['optimizer'] + opt_config_file = config["optimizer"] opt_config = launcher.config_loader.load_config(opt_config_file, ConfigSchema.OPTIMIZER) globals_file_config = launcher.config_loader.load_config(globals_file, ConfigSchema.GLOBALS) # The actual global_config gets overwritten as a part of processing, so to test # this we read the original value out of the source files. - orig_max_iters = globals_file_config.get('max_suggestions', opt_config.get('config', {}).get('max_suggestions', 100)) - assert launcher.optimizer.max_iterations \ - == orig_max_iters \ - == launcher.global_config['max_suggestions'] + orig_max_iters = globals_file_config.get( + "max_suggestions", opt_config.get("config", {}).get("max_suggestions", 100) + ) + assert ( + launcher.optimizer.max_iterations + == orig_max_iters + == launcher.global_config["max_suggestions"] + ) # Check that the optimizer got initialized with random values instead of the defaults. # Note: the environment doesn't get updated until suggest() is called to @@ -193,12 +206,12 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access # Check that the value from the file is overridden by the CLI arg. - assert config['random_seed'] == 42 + assert config["random_seed"] == 42 # TODO: This isn't actually respected yet because the `--random-init` only # applies to a temporary Optimizer used to populate the initial values via # random sampling. 
# assert launcher.optimizer.seed == 1234 -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__, "-n1"]) diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index 04aad14faf..d6f5b8cfd5 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -25,16 +25,21 @@ def root_path() -> str: @pytest.fixture def local_exec_service() -> LocalExecService: """Test fixture for LocalExecService.""" - return LocalExecService(parent=ConfigPersistenceService({ - "config_path": [ - "mlos_bench/config", - "mlos_bench/examples", - ] - })) + return LocalExecService( + parent=ConfigPersistenceService( + { + "config_path": [ + "mlos_bench/config", + "mlos_bench/examples", + ] + } + ) + ) -def _launch_main_app(root_path: str, local_exec_service: LocalExecService, - cli_config: str, re_expected: List[str]) -> None: +def _launch_main_app( + root_path: str, local_exec_service: LocalExecService, cli_config: str, re_expected: List[str] +) -> None: """Run mlos_bench command-line application with given config and check the results in the log. """ @@ -45,10 +50,13 @@ def _launch_main_app(root_path: str, local_exec_service: LocalExecService, # temp_dir = '/tmp' log_path = path_join(temp_dir, "mock-test.log") (return_code, _stdout, _stderr) = local_exec_service.local_exec( - ["./mlos_bench/mlos_bench/run.py" + - " --config_path ./mlos_bench/mlos_bench/tests/config/" + - f" {cli_config} --log_file '{log_path}'"], - cwd=root_path) + [ + "./mlos_bench/mlos_bench/run.py" + + " --config_path ./mlos_bench/mlos_bench/tests/config/" + + f" {cli_config} --log_file '{log_path}'" + ], + cwd=root_path, + ) assert return_code == 0 try: @@ -71,32 +79,33 @@ def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecServ tunable values and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - " --config cli/mock-bench.jsonc" + - " --trial_config_repeat_count 5" + - " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, + local_exec_service, + " --config cli/mock-bench.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. [ - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.40\d+\}\s*$", + ], ) def test_launch_main_app_bench_values( - root_path: str, local_exec_service: LocalExecService) -> None: + root_path: str, local_exec_service: LocalExecService +) -> None: """Run mlos_bench command-line application with mock benchmark config and user- specified tunable values and check the results in the log. """ _launch_main_app( - root_path, local_exec_service, - " --config cli/mock-bench.jsonc" + - " --tunable_values tunable-values/tunable-values-example.jsonc" + - " --trial_config_repeat_count 5" + - " --mock_env_seed -1", # Deterministic Mock Environment. + root_path, + local_exec_service, + " --config cli/mock-bench.jsonc" + + " --tunable_values tunable-values/tunable-values-example.jsonc" + + " --trial_config_repeat_count 5" + + " --mock_env_seed -1", # Deterministic Mock Environment. 
[ - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 67\.11\d+\}\s*$", + ], ) @@ -105,23 +114,23 @@ def test_launch_main_app_opt(root_path: str, local_exec_service: LocalExecServic the results in the log. """ _launch_main_app( - root_path, local_exec_service, - "--config cli/mock-opt.jsonc" + - " --trial_config_repeat_count 3" + - " --max_suggestions 3" + - " --mock_env_seed 42", # Noisy Mock Environment. + root_path, + local_exec_service, + "--config cli/mock-opt.jsonc" + + " --trial_config_repeat_count 3" + + " --max_suggestions 3" + + " --mock_env_seed 42", # Noisy Mock Environment. [ # Iteration 1: Expect first value to be the baseline - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': 64\.53\d+\}$", # Iteration 2: The result may not always be deterministic - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Iteration 3: non-deterministic (depends on the optimizer) - f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + - r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", + f"^{_RE_DATE} mlos_core_optimizer\\.py:\\d+ " + + r"bulk_register DEBUG Warm-up END: .* :: \{'score': \d+\.\d+\}$", # Final result: baseline is the optimum for the mock environment - f"^{_RE_DATE} run\\.py:\\d+ " + - r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", - ] + f"^{_RE_DATE} run\\.py:\\d+ " + r"_main INFO Final score: \{'score': 64\.53\d+\}\s*$", + ], ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/conftest.py b/mlos_bench/mlos_bench/tests/optimizers/conftest.py index 810f4fcc0e..6b660f7fea 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/conftest.py +++ b/mlos_bench/mlos_bench/tests/optimizers/conftest.py @@ -19,29 +19,29 @@ def mock_configs() -> List[dict]: """Mock configurations of earlier experiments.""" return [ { - 'vmSize': 'Standard_B4ms', - 'idle': 'halt', - 'kernel_sched_migration_cost_ns': 50000, - 'kernel_sched_latency_ns': 1000000, + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 50000, + "kernel_sched_latency_ns": 1000000, }, { - 'vmSize': 'Standard_B4ms', - 'idle': 'halt', - 'kernel_sched_migration_cost_ns': 40000, - 'kernel_sched_latency_ns': 2000000, + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 40000, + "kernel_sched_latency_ns": 2000000, }, { - 'vmSize': 'Standard_B4ms', - 'idle': 'mwait', - 'kernel_sched_migration_cost_ns': -1, # Special value - 'kernel_sched_latency_ns': 3000000, + "vmSize": "Standard_B4ms", + "idle": "mwait", + "kernel_sched_migration_cost_ns": -1, # Special value + "kernel_sched_latency_ns": 3000000, }, { - 'vmSize': 'Standard_B2s', - 'idle': 'mwait', - 'kernel_sched_migration_cost_ns': 200000, - 'kernel_sched_latency_ns': 4000000, - } + "vmSize": "Standard_B2s", + "idle": "mwait", + "kernel_sched_migration_cost_ns": 200000, + "kernel_sched_latency_ns": 4000000, + }, ] @@ -55,7 +55,7 @@ def mock_opt_no_defaults(tunable_groups: TunableGroups) -> MockOptimizer: "optimization_targets": {"score": "min"}, "max_suggestions": 5, "start_with_defaults": False, - "seed": SEED + "seed": SEED, }, ) @@ -66,11 +66,7 @@ def 
mock_opt(tunable_groups: TunableGroups) -> MockOptimizer: return MockOptimizer( tunables=tunable_groups, service=None, - config={ - "optimization_targets": {"score": "min"}, - "max_suggestions": 5, - "seed": SEED - }, + config={"optimization_targets": {"score": "min"}, "max_suggestions": 5, "seed": SEED}, ) @@ -80,11 +76,7 @@ def mock_opt_max(tunable_groups: TunableGroups) -> MockOptimizer: return MockOptimizer( tunables=tunable_groups, service=None, - config={ - "optimization_targets": {"score": "max"}, - "max_suggestions": 10, - "seed": SEED - }, + config={"optimization_targets": {"score": "max"}, "max_suggestions": 10, "seed": SEED}, ) diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index 077b2ed058..8761201c8e 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -48,15 +48,23 @@ def grid_search_tunables_config() -> dict: @pytest.fixture -def grid_search_tunables_grid(grid_search_tunables: TunableGroups) -> List[Dict[str, TunableValue]]: +def grid_search_tunables_grid( + grid_search_tunables: TunableGroups, +) -> List[Dict[str, TunableValue]]: """ Test fixture for grid from tunable groups. Used to check that the grids are the same (ignoring order). """ - tunables_params_values = [tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None] - tunable_names = tuple(tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None) - return list(dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values)) + tunables_params_values = [ + tunable.values for tunable, _group in grid_search_tunables if tunable.values is not None + ] + tunable_names = tuple( + tunable.name for tunable, _group in grid_search_tunables if tunable.values is not None + ) + return list( + dict(zip(tunable_names, combo)) for combo in itertools.product(*tunables_params_values) + ) @pytest.fixture @@ -66,22 +74,28 @@ def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups: @pytest.fixture -def grid_search_opt(grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> GridSearchOptimizer: +def grid_search_opt( + grid_search_tunables: TunableGroups, grid_search_tunables_grid: List[Dict[str, TunableValue]] +) -> GridSearchOptimizer: """Test fixture for grid search optimizer.""" assert len(grid_search_tunables) == 3 # Test the convergence logic by controlling the number of iterations to be not a # multiple of the number of elements in the grid. 
max_iterations = len(grid_search_tunables_grid) * 2 - 3 - return GridSearchOptimizer(tunables=grid_search_tunables, config={ - "max_suggestions": max_iterations, - "optimization_targets": {"score": "max", "other_score": "min"}, - }) + return GridSearchOptimizer( + tunables=grid_search_tunables, + config={ + "max_suggestions": max_iterations, + "optimization_targets": {"score": "max", "other_score": "min"}, + }, + ) -def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: +def test_grid_search_grid( + grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], +) -> None: """Make sure that grid search optimizer initializes and works correctly.""" # Check the size. expected_grid_size = math.prod(tunable.cardinality for tunable, _group in grid_search_tunables) @@ -106,9 +120,11 @@ def test_grid_search_grid(grid_search_opt: GridSearchOptimizer, # assert grid_search_opt.pending_configs == grid_search_tunables_grid -def test_grid_search(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups, - grid_search_tunables_grid: List[Dict[str, TunableValue]]) -> None: +def test_grid_search( + grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], +) -> None: """Make sure that grid search optimizer initializes and works correctly.""" score: Dict[str, TunableValue] = {"score": 1.0, "other_score": 2.0} status = Status.SUCCEEDED @@ -133,7 +149,9 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, grid_search_tunables_grid.remove(default_config) assert default_config not in grid_search_opt.pending_configs assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) + assert all( + config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid + ) # The next suggestion should be a different element in the grid search. suggestion = grid_search_opt.suggest() @@ -147,7 +165,9 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, grid_search_tunables_grid.remove(suggestion.get_param_values()) assert all(config in grid_search_tunables_grid for config in grid_search_opt.pending_configs) - assert all(config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid) + assert all( + config in list(grid_search_opt.pending_configs) for config in grid_search_tunables_grid + ) # We consider not_converged as either having reached "max_suggestions" or an empty grid? @@ -161,7 +181,8 @@ def test_grid_search(grid_search_opt: GridSearchOptimizer, assert not list(grid_search_opt.suggested_configs) assert not grid_search_opt.not_converged() - # But if we still have iterations left, we should be able to suggest again by refilling the grid. + # But if we still have iterations left, we should be able to suggest again by + # refilling the grid. 
assert grid_search_opt.current_iteration < grid_search_opt.max_iterations assert grid_search_opt.suggest() assert list(grid_search_opt.pending_configs) @@ -212,7 +233,7 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: assert best_suggestion_dict not in grid_search_opt.suggested_configs best_suggestion_score: Dict[str, TunableValue] = {} - for (opt_target, opt_dir) in grid_search_opt.targets.items(): + for opt_target, opt_dir in grid_search_opt.targets.items(): val = score[opt_target] assert isinstance(val, (int, float)) best_suggestion_score[opt_target] = val - 1 if opt_dir == "min" else val + 1 @@ -226,34 +247,52 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: # Check bulk register suggested = [grid_search_opt.suggest() for _ in range(suggest_count)] - assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) - assert all(suggestion.get_param_values() in grid_search_opt.suggested_configs for suggestion in suggested) + assert all( + suggestion.get_param_values() not in grid_search_opt.pending_configs + for suggestion in suggested + ) + assert all( + suggestion.get_param_values() in grid_search_opt.suggested_configs + for suggestion in suggested + ) # Those new suggestions also shouldn't be in the set of previously suggested configs. assert all(suggestion.get_param_values() not in suggested_shuffled for suggestion in suggested) - grid_search_opt.bulk_register([suggestion.get_param_values() for suggestion in suggested], - [score] * len(suggested), - [status] * len(suggested)) - - assert all(suggestion.get_param_values() not in grid_search_opt.pending_configs for suggestion in suggested) - assert all(suggestion.get_param_values() not in grid_search_opt.suggested_configs for suggestion in suggested) + grid_search_opt.bulk_register( + [suggestion.get_param_values() for suggestion in suggested], + [score] * len(suggested), + [status] * len(suggested), + ) + + assert all( + suggestion.get_param_values() not in grid_search_opt.pending_configs + for suggestion in suggested + ) + assert all( + suggestion.get_param_values() not in grid_search_opt.suggested_configs + for suggestion in suggested + ) best_score, best_config = grid_search_opt.get_best_observation() assert best_score == best_suggestion_score assert best_config == best_suggestion -def test_grid_search_register(grid_search_opt: GridSearchOptimizer, - grid_search_tunables: TunableGroups) -> None: +def test_grid_search_register( + grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups +) -> None: """Make sure that the `.register()` method adjusts the score signs correctly.""" assert grid_search_opt.register( - grid_search_tunables, Status.SUCCEEDED, { + grid_search_tunables, + Status.SUCCEEDED, + { "score": 1.0, "other_score": 2.0, - }) == { - "score": -1.0, # max - "other_score": 2.0, # min + }, + ) == { + "score": -1.0, # max + "other_score": 2.0, # min } assert grid_search_opt.register(grid_search_tunables, Status.FAILED) == { diff --git a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py index 0a0add5b24..4494cba3ef 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/llamatune_opt_test.py @@ -30,7 +30,8 @@ def llamatune_opt(tunable_groups: TunableGroups) -> MlosCoreOptimizer: "optimizer_type": "SMAC", "seed": SEED, # "start_with_defaults": False, - }) + }, + ) 
@pytest.fixture @@ -53,6 +54,6 @@ def test_llamatune_optimizer(llamatune_opt: MlosCoreOptimizer, mock_scores: list assert best_score["score"] == pytest.approx(66.66, 0.01) -if __name__ == '__main__': +if __name__ == "__main__": # For attaching debugger debugging: pytest.main(["-vv", "-n1", "-k", "test_llamatune_optimizer", __file__]) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py index 47768f87a4..043e457375 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_df_test.py @@ -20,9 +20,9 @@ def mlos_core_optimizer(tunable_groups: TunableGroups) -> MlosCoreOptimizer: """An instance of a mlos_core optimizer (FLAML-based).""" test_opt_config = { - 'optimizer_type': 'FLAML', - 'max_suggestions': 10, - 'seed': SEED, + "optimizer_type": "FLAML", + "max_suggestions": 10, + "seed": SEED, } return MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -33,44 +33,44 @@ def test_df(mlos_core_optimizer: MlosCoreOptimizer, mock_configs: List[dict]) -> assert isinstance(df_config, pandas.DataFrame) assert df_config.shape == (4, 6) assert set(df_config.columns) == { - 'kernel_sched_latency_ns', - 'kernel_sched_migration_cost_ns', - 'kernel_sched_migration_cost_ns!type', - 'kernel_sched_migration_cost_ns!special', - 'idle', - 'vmSize', + "kernel_sched_latency_ns", + "kernel_sched_migration_cost_ns", + "kernel_sched_migration_cost_ns!type", + "kernel_sched_migration_cost_ns!special", + "idle", + "vmSize", } - assert df_config.to_dict(orient='records') == [ + assert df_config.to_dict(orient="records") == [ { - 'idle': 'halt', - 'kernel_sched_latency_ns': 1000000, - 'kernel_sched_migration_cost_ns': 50000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B4ms', + "idle": "halt", + "kernel_sched_latency_ns": 1000000, + "kernel_sched_migration_cost_ns": 50000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B4ms", }, { - 'idle': 'halt', - 'kernel_sched_latency_ns': 2000000, - 'kernel_sched_migration_cost_ns': 40000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B4ms', + "idle": "halt", + "kernel_sched_latency_ns": 2000000, + "kernel_sched_migration_cost_ns": 40000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B4ms", }, { - 'idle': 'mwait', - 'kernel_sched_latency_ns': 3000000, - 'kernel_sched_migration_cost_ns': None, # The value is special! - 'kernel_sched_migration_cost_ns!special': -1, - 'kernel_sched_migration_cost_ns!type': 'special', - 'vmSize': 'Standard_B4ms', + "idle": "mwait", + "kernel_sched_latency_ns": 3000000, + "kernel_sched_migration_cost_ns": None, # The value is special! 
+ "kernel_sched_migration_cost_ns!special": -1, + "kernel_sched_migration_cost_ns!type": "special", + "vmSize": "Standard_B4ms", }, { - 'idle': 'mwait', - 'kernel_sched_latency_ns': 4000000, - 'kernel_sched_migration_cost_ns': 200000, - 'kernel_sched_migration_cost_ns!special': None, - 'kernel_sched_migration_cost_ns!type': 'range', - 'vmSize': 'Standard_B2s', + "idle": "mwait", + "kernel_sched_latency_ns": 4000000, + "kernel_sched_migration_cost_ns": 200000, + "kernel_sched_migration_cost_ns!special": None, + "kernel_sched_migration_cost_ns!type": "range", + "vmSize": "Standard_B2s", }, ] diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index c97c97cf0d..3b45d7dcd6 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -15,17 +15,17 @@ from mlos_bench.util import path_join from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer -_OUTPUT_DIR_PATH_BASE = r'c:/temp' if sys.platform == 'win32' else '/tmp/' -_OUTPUT_DIR = '_test_output_dir' # Will be deleted after the test. +_OUTPUT_DIR_PATH_BASE = r"c:/temp" if sys.platform == "win32" else "/tmp/" +_OUTPUT_DIR = "_test_output_dir" # Will be deleted after the test. def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) -> None: """Test invalid max_trials initialization of mlos_core SMAC optimizer.""" test_opt_config = { - 'optimizer_type': 'SMAC', - 'max_trials': 10, - 'max_suggestions': 11, - 'seed': SEED, + "optimizer_type": "SMAC", + "max_trials": 10, + "max_suggestions": 11, + "seed": SEED, } with pytest.raises(AssertionError): opt = MlosCoreOptimizer(tunable_groups, test_opt_config) @@ -35,14 +35,14 @@ def test_init_mlos_core_smac_opt_bad_trial_count(tunable_groups: TunableGroups) def test_init_mlos_core_smac_opt_max_trials(tunable_groups: TunableGroups) -> None: """Test max_trials initialization of mlos_core SMAC optimizer.""" test_opt_config = { - 'optimizer_type': 'SMAC', - 'max_suggestions': 123, - 'seed': SEED, + "optimizer_type": "SMAC", + "max_suggestions": 123, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) - assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config['max_suggestions'] + assert opt._opt.base_optimizer.scenario.n_trials == test_opt_config["max_suggestions"] def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGroups) -> None: @@ -51,9 +51,9 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr """ output_dir = path_join(_OUTPUT_DIR_PATH_BASE, _OUTPUT_DIR) test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': output_dir, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": output_dir, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) @@ -61,7 +61,8 @@ def test_init_mlos_core_smac_absolute_output_directory(tunable_groups: TunableGr assert isinstance(opt._opt, SmacOptimizer) # Final portions of the path are generated by SMAC when run_name is not specified. 
assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - str(test_opt_config['output_directory'])) + str(test_opt_config["output_directory"]) + ) shutil.rmtree(output_dir) @@ -70,65 +71,76 @@ def test_init_mlos_core_smac_relative_output_directory(tunable_groups: TunableGr optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']))) + path_join(os.getcwd(), str(test_opt_config["output_directory"])) + ) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_run_name(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_run_name( + tunable_groups: TunableGroups, +) -> None: """Test relative path output directory initialization of mlos_core SMAC optimizer. """ test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'run_name': 'test_run', - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "run_name": "test_run", + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']), str(test_opt_config['run_name']))) + path_join( + os.getcwd(), str(test_opt_config["output_directory"]), str(test_opt_config["run_name"]) + ) + ) shutil.rmtree(_OUTPUT_DIR) -def test_init_mlos_core_smac_relative_output_directory_with_experiment_id(tunable_groups: TunableGroups) -> None: +def test_init_mlos_core_smac_relative_output_directory_with_experiment_id( + tunable_groups: TunableGroups, +) -> None: """Test relative path output directory initialization of mlos_core SMAC optimizer. 
""" test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': _OUTPUT_DIR, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": _OUTPUT_DIR, + "seed": SEED, } global_config = { - 'experiment_id': 'experiment_id', + "experiment_id": "experiment_id", } opt = MlosCoreOptimizer(tunable_groups, test_opt_config, global_config) assert isinstance(opt, MlosCoreOptimizer) # pylint: disable=protected-access assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( - path_join(os.getcwd(), str(test_opt_config['output_directory']), global_config['experiment_id'])) + path_join( + os.getcwd(), str(test_opt_config["output_directory"]), global_config["experiment_id"] + ) + ) shutil.rmtree(_OUTPUT_DIR) def test_init_mlos_core_smac_temp_output_directory(tunable_groups: TunableGroups) -> None: """Test random output directory initialization of mlos_core SMAC optimizer.""" test_opt_config = { - 'optimizer_type': 'SMAC', - 'output_directory': None, - 'seed': SEED, + "optimizer_type": "SMAC", + "output_directory": None, + "seed": SEED, } opt = MlosCoreOptimizer(tunable_groups, test_opt_config) assert isinstance(opt, MlosCoreOptimizer) diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py index 1ce5903306..ee41f95b13 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py @@ -16,24 +16,33 @@ def mock_configurations_no_defaults() -> list: """A list of 2-tuples of (tunable_values, score) to test the optimizers.""" return [ - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 13112, - "kernel_sched_latency_ns": 796233790, - }, 88.88), - ({ - "vmSize": "Standard_B2ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 117026, - "kernel_sched_latency_ns": 149827706, - }, 66.66), - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": 354785, - "kernel_sched_latency_ns": 795285932, - }, 99.99), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 13112, + "kernel_sched_latency_ns": 796233790, + }, + 88.88, + ), + ( + { + "vmSize": "Standard_B2ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 117026, + "kernel_sched_latency_ns": 149827706, + }, + 66.66, + ), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": 354785, + "kernel_sched_latency_ns": 795285932, + }, + 99.99, + ), ] @@ -41,18 +50,21 @@ def mock_configurations_no_defaults() -> list: def mock_configurations(mock_configurations_no_defaults: list) -> list: """A list of 2-tuples of (tunable_values, score) to test the optimizers.""" return [ - ({ - "vmSize": "Standard_B4ms", - "idle": "halt", - "kernel_sched_migration_cost_ns": -1, - "kernel_sched_latency_ns": 2000000, - }, 88.88), + ( + { + "vmSize": "Standard_B4ms", + "idle": "halt", + "kernel_sched_migration_cost_ns": -1, + "kernel_sched_latency_ns": 2000000, + }, + 88.88, + ), ] + mock_configurations_no_defaults def _optimize(mock_opt: MockOptimizer, mock_configurations: list) -> float: """Run several iterations of the optimizer and return the best score.""" - for (tunable_values, score) in mock_configurations: + for tunable_values, score in mock_configurations: assert mock_opt.not_converged() tunables = mock_opt.suggest() assert tunables.get_param_values() == tunable_values @@ -70,8 +82,9 @@ def 
test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> N assert score == pytest.approx(66.66, 0.01) -def test_mock_optimizer_no_defaults(mock_opt_no_defaults: MockOptimizer, - mock_configurations_no_defaults: list) -> None: +def test_mock_optimizer_no_defaults( + mock_opt_no_defaults: MockOptimizer, mock_configurations_no_defaults: list +) -> None: """Make sure that mock optimizer produces consistent suggestions.""" score = _optimize(mock_opt_no_defaults, mock_configurations_no_defaults) assert score == pytest.approx(66.66, 0.01) diff --git a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py index dd832ce348..cbbd2a627d 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/opt_bulk_register_test.py @@ -24,10 +24,7 @@ def mock_configs_str(mock_configs: List[dict]) -> List[dict]: (This can happen when we retrieve the data from storage). """ - return [ - {key: str(val) for (key, val) in config.items()} - for config in mock_configs - ] + return [{key: str(val) for (key, val) in config.items()} for config in mock_configs] @pytest.fixture @@ -47,10 +44,12 @@ def mock_status() -> List[Status]: return [Status.FAILED, Status.SUCCEEDED, Status.SUCCEEDED, Status.SUCCEEDED] -def _test_opt_update_min(opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None) -> None: +def _test_opt_update_min( + opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None, +) -> None: """Test the bulk update of the optimizer on the minimization problem.""" opt.bulk_register(configs, scores, status) (score, tunables) = opt.get_best_observation() @@ -61,14 +60,16 @@ def _test_opt_update_min(opt: Optimizer, "vmSize": "Standard_B4ms", "idle": "mwait", "kernel_sched_migration_cost_ns": -1, - 'kernel_sched_latency_ns': 3000000, + "kernel_sched_latency_ns": 3000000, } -def _test_opt_update_max(opt: Optimizer, - configs: List[dict], - scores: List[Optional[Dict[str, TunableValue]]], - status: Optional[List[Status]] = None) -> None: +def _test_opt_update_max( + opt: Optimizer, + configs: List[dict], + scores: List[Optional[Dict[str, TunableValue]]], + status: Optional[List[Status]] = None, +) -> None: """Test the bulk update of the optimizer on the maximization problem.""" opt.bulk_register(configs, scores, status) (score, tunables) = opt.get_best_observation() @@ -79,14 +80,16 @@ def _test_opt_update_max(opt: Optimizer, "vmSize": "Standard_B2s", "idle": "mwait", "kernel_sched_migration_cost_ns": 200000, - 'kernel_sched_latency_ns': 4000000, + "kernel_sched_latency_ns": 4000000, } -def test_update_mock_min(mock_opt: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_min( + mock_opt: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """Test the bulk update of the mock optimizer on the minimization problem.""" _test_opt_update_min(mock_opt, mock_configs, mock_scores, mock_status) # make sure the first suggestion after bulk load is *NOT* the default config: @@ -94,53 +97,65 @@ def test_update_mock_min(mock_opt: MockOptimizer, "vmSize": "Standard_B4ms", "idle": "halt", "kernel_sched_migration_cost_ns": 13112, - 
'kernel_sched_latency_ns': 796233790, + "kernel_sched_latency_ns": 796233790, } -def test_update_mock_min_str(mock_opt: MockOptimizer, - mock_configs_str: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_min_str( + mock_opt: MockOptimizer, + mock_configs_str: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """Test the bulk update of the mock optimizer with all-strings data.""" _test_opt_update_min(mock_opt, mock_configs_str, mock_scores, mock_status) -def test_update_mock_max(mock_opt_max: MockOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_mock_max( + mock_opt_max: MockOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """Test the bulk update of the mock optimizer on the maximization problem.""" _test_opt_update_max(mock_opt_max, mock_configs, mock_scores, mock_status) -def test_update_flaml(flaml_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_flaml( + flaml_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """Test the bulk update of the FLAML optimizer.""" _test_opt_update_min(flaml_opt, mock_configs, mock_scores, mock_status) -def test_update_flaml_max(flaml_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_flaml_max( + flaml_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """Test the bulk update of the FLAML optimizer.""" _test_opt_update_max(flaml_opt_max, mock_configs, mock_scores, mock_status) -def test_update_smac(smac_opt: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_smac( + smac_opt: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """Test the bulk update of the SMAC optimizer.""" _test_opt_update_min(smac_opt, mock_configs, mock_scores, mock_status) -def test_update_smac_max(smac_opt_max: MlosCoreOptimizer, - mock_configs: List[dict], - mock_scores: List[Optional[Dict[str, TunableValue]]], - mock_status: List[Status]) -> None: +def test_update_smac_max( + smac_opt_max: MlosCoreOptimizer, + mock_configs: List[dict], + mock_scores: List[Optional[Dict[str, TunableValue]]], + mock_status: List[Status], +) -> None: """Test the bulk update of the SMAC optimizer.""" _test_opt_update_max(smac_opt_max, mock_configs, mock_scores, mock_status) diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index c845f87549..1596d4997d 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -52,7 +52,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: (status, _ts, output) = env_context.run() assert 
status.is_succeeded() assert output is not None - score = output['score'] + score = output["score"] assert isinstance(score, float) assert 60 <= score <= 120 logger("score: %s", str(score)) @@ -65,8 +65,7 @@ def _optimize(env: Environment, opt: Optimizer) -> Tuple[float, TunableGroups]: return (best_score["score"], best_tunables) -def test_mock_optimization_loop(mock_env_no_noise: MockEnv, - mock_opt: MockOptimizer) -> None: +def test_mock_optimization_loop(mock_env_no_noise: MockEnv, mock_opt: MockOptimizer) -> None: """Toy optimization loop with mock environment and optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, mock_opt) assert score == pytest.approx(64.9, 0.01) @@ -78,8 +77,9 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv, } -def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, - mock_opt_no_defaults: MockOptimizer) -> None: +def test_mock_optimization_loop_no_defaults( + mock_env_no_noise: MockEnv, mock_opt_no_defaults: MockOptimizer +) -> None: """Toy optimization loop with mock environment and optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, mock_opt_no_defaults) assert score == pytest.approx(60.97, 0.01) @@ -91,8 +91,7 @@ def test_mock_optimization_loop_no_defaults(mock_env_no_noise: MockEnv, } -def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, - flaml_opt: MlosCoreOptimizer) -> None: +def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, flaml_opt: MlosCoreOptimizer) -> None: """Toy optimization loop with mock environment and FLAML optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, flaml_opt) assert score == pytest.approx(60.15, 0.01) @@ -105,8 +104,7 @@ def test_flaml_optimization_loop(mock_env_no_noise: MockEnv, # @pytest.mark.skip(reason="SMAC is not deterministic") -def test_smac_optimization_loop(mock_env_no_noise: MockEnv, - smac_opt: MlosCoreOptimizer) -> None: +def test_smac_optimization_loop(mock_env_no_noise: MockEnv, smac_opt: MlosCoreOptimizer) -> None: """Toy optimization loop with mock environment and SMAC optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, smac_opt) expected_score = 70.33 diff --git a/mlos_bench/mlos_bench/tests/services/__init__.py b/mlos_bench/mlos_bench/tests/services/__init__.py index fa411976e6..a0b56eeb03 100644 --- a/mlos_bench/mlos_bench/tests/services/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/__init__.py @@ -12,8 +12,8 @@ from .remote import MockFileShareService, MockRemoteExecService, MockVMService __all__ = [ - 'MockLocalExecService', - 'MockFileShareService', - 'MockRemoteExecService', - 'MockVMService', + "MockLocalExecService", + "MockFileShareService", + "MockRemoteExecService", + "MockVMService", ] diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index 067715f7e4..3f8a6514ed 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -25,15 +25,19 @@ @pytest.fixture def config_persistence_service() -> ConfigPersistenceService: """Test fixture for ConfigPersistenceService.""" - return ConfigPersistenceService({ - "config_path": [ - "./non-existent-dir/test/foo/bar", # Non-existent config path - ".", # cwd - str(files("mlos_bench.tests.config").joinpath("")), # Test configs (relative to mlos_bench/tests) - # Shouldn't be necessary since we automatically add this. 
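
A note on the `pytest.approx(66.66, 0.01)`-style assertions in the optimizer tests above: the second positional argument is a relative tolerance, not an absolute one. A minimal standalone sketch (the numbers are made up for illustration):

    import pytest

    def test_approx_relative_tolerance() -> None:
        score = 66.5
        # 0.01 is a *relative* tolerance: anything within 1% of 66.66
        # (i.e. +/- 0.6666) passes this assertion.
        assert score == pytest.approx(66.66, 0.01)
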
- # str(files("mlos_bench.config").joinpath("")), # Stock configs - ] - }) + return ConfigPersistenceService( + { + "config_path": [ + "./non-existent-dir/test/foo/bar", # Non-existent config path + ".", # cwd + str( + files("mlos_bench.tests.config").joinpath("") + ), # Test configs (relative to mlos_bench/tests) + # Shouldn't be necessary since we automatically add this. + # str(files("mlos_bench.config").joinpath("")), # Stock configs + ] + } + ) def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None: @@ -68,7 +72,7 @@ def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService assert os.path.exists(path) assert os.path.samefile( ConfigPersistenceService.BUILTIN_CONFIG_PATH, - os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]) + os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]), ) @@ -92,8 +96,9 @@ def test_load_config(config_persistence_service: ConfigPersistenceService) -> No """Check if we can successfully load a config file located relative to `config_path`. """ - tunables_data = config_persistence_service.load_config("tunable-values/tunable-values-example.jsonc", - ConfigSchema.TUNABLE_VALUES) + tunables_data = config_persistence_service.load_config( + "tunable-values/tunable-values-example.jsonc", ConfigSchema.TUNABLE_VALUES + ) assert tunables_data is not None assert isinstance(tunables_data, dict) assert len(tunables_data) >= 1 diff --git a/mlos_bench/mlos_bench/tests/services/local/__init__.py b/mlos_bench/mlos_bench/tests/services/local/__init__.py index 01f6e04dcf..79778d3c25 100644 --- a/mlos_bench/mlos_bench/tests/services/local/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/__init__.py @@ -11,5 +11,5 @@ from .mock import MockLocalExecService __all__ = [ - 'MockLocalExecService', + "MockLocalExecService", ] diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index e1da1105b0..c52e643025 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -50,11 +50,12 @@ def test_run_python_script(local_exec_service: LocalExecService) -> None: json.dump(params_meta, fh_meta) script_path = local_exec_service.config_loader_service.resolve_path( - "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py") + "environments/os/linux/runtime/scripts/local/generate_kernel_config_script.py" + ) - (return_code, _stdout, stderr) = local_exec_service.local_exec([ - f"{script_path} {input_file} {meta_file} {output_file}" - ], cwd=temp_dir, env=params) + (return_code, _stdout, stderr) = local_exec_service.local_exec( + [f"{script_path} {input_file} {meta_file} {output_file}"], cwd=temp_dir, env=params + ) assert stderr.strip() == "" assert return_code == 0 diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index 572c332282..6a64398fc3 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -20,25 +20,27 @@ def test_split_cmdline() -> None: """Test splitting a commandline into subcommands.""" - cmdline = ". env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" + cmdline = ( + ". 
env.sh && (echo hello && echo world | tee > /tmp/test || echo foo && echo $var; true)" + ) assert list(split_cmdline(cmdline)) == [ - ['.', 'env.sh'], - ['&&'], - ['('], - ['echo', 'hello'], - ['&&'], - ['echo', 'world'], - ['|'], - ['tee'], - ['>'], - ['/tmp/test'], - ['||'], - ['echo', 'foo'], - ['&&'], - ['echo', '$var'], - [';'], - ['true'], - [')'], + [".", "env.sh"], + ["&&"], + ["("], + ["echo", "hello"], + ["&&"], + ["echo", "world"], + ["|"], + ["tee"], + [">"], + ["/tmp/test"], + ["||"], + ["echo", "foo"], + ["&&"], + ["echo", "$var"], + [";"], + ["true"], + [")"], ] @@ -59,7 +61,10 @@ def test_resolve_script(local_exec_service: LocalExecService) -> None: expected_cmdline = f". env.sh && {script_abspath} --input foo" subcmds_tokens = split_cmdline(orig_cmdline) # pylint: disable=protected-access - subcmds_tokens = [local_exec_service._resolve_cmdline_script_path(subcmd_tokens) for subcmd_tokens in subcmds_tokens] + subcmds_tokens = [ + local_exec_service._resolve_cmdline_script_path(subcmd_tokens) + for subcmd_tokens in subcmds_tokens + ] cmdline_tokens = [token for subcmd_tokens in subcmds_tokens for token in subcmd_tokens] expanded_cmdline = " ".join(cmdline_tokens) assert expanded_cmdline == expected_cmdline @@ -77,10 +82,7 @@ def test_run_script(local_exec_service: LocalExecService) -> None: def test_run_script_multiline(local_exec_service: LocalExecService) -> None: """Run a multiline script locally and check the results.""" # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec([ - "echo hello", - "echo world" - ]) + (return_code, stdout, stderr) = local_exec_service.local_exec(["echo hello", "echo world"]) assert return_code == 0 assert stdout.strip().split() == ["hello", "world"] assert stderr.strip() == "" @@ -89,12 +91,12 @@ def test_run_script_multiline(local_exec_service: LocalExecService) -> None: def test_run_script_multiline_env(local_exec_service: LocalExecService) -> None: """Run a multiline script locally and pass the environment variables to it.""" # `echo` should work on all platforms - (return_code, stdout, stderr) = local_exec_service.local_exec([ - r"echo $var", # Unix shell - r"echo %var%" # Windows cmd - ], env={"var": "VALUE", "int_var": 10}) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [r"echo $var", r"echo %var%"], # Unix shell # Windows cmd + env={"var": "VALUE", "int_var": 10}, + ) assert return_code == 0 - if sys.platform == 'win32': + if sys.platform == "win32": assert stdout.strip().split() == ["$var", "VALUE"] else: assert stdout.strip().split() == ["VALUE", "%var%"] @@ -105,23 +107,26 @@ def test_run_script_read_csv(local_exec_service: LocalExecService) -> None: """Run a script locally and read the resulting CSV file.""" with local_exec_service.temp_dir_context() as temp_dir: - (return_code, stdout, stderr) = local_exec_service.local_exec([ - "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows - "echo '111,222' >> output.csv", - "echo '333,444' >> output.csv", - ], cwd=temp_dir) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [ + "echo 'col1,col2'> output.csv", # No space before '>' to make it work on Windows + "echo '111,222' >> output.csv", + "echo '333,444' >> output.csv", + ], + cwd=temp_dir, + ) assert return_code == 0 assert stdout.strip() == "" assert stderr.strip() == "" data = pandas.read_csv(path_join(temp_dir, "output.csv")) - if sys.platform == 'win32': + if sys.platform == "win32": # Workaround for Python's subprocess 
module on Windows adding a # space inbetween the col1,col2 arg and the redirect symbol which # cmd poorly interprets as being part of the original string arg. # Without this, we get "col2 " as the second column name. - data.rename(str.rstrip, axis='columns', inplace=True) + data.rename(str.rstrip, axis="columns", inplace=True) assert all(data.col1 == [111, 333]) assert all(data.col2 == [222, 444]) @@ -134,10 +139,13 @@ def test_run_script_write_read_txt(local_exec_service: LocalExecService) -> None with open(path_join(temp_dir, input_file), "wt", encoding="utf-8") as fh_input: fh_input.write("hello\n") - (return_code, stdout, stderr) = local_exec_service.local_exec([ - f"echo 'world' >> {input_file}", - f"echo 'test' >> {input_file}", - ], cwd=temp_dir) + (return_code, stdout, stderr) = local_exec_service.local_exec( + [ + f"echo 'world' >> {input_file}", + f"echo 'test' >> {input_file}", + ], + cwd=temp_dir, + ) assert return_code == 0 assert stdout.strip() == "" @@ -156,11 +164,13 @@ def test_run_script_fail(local_exec_service: LocalExecService) -> None: def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> None: """Try to run a series of commands, one of which fails, and abort early.""" - (return_code, stdout, _stderr) = local_exec_service.local_exec([ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", - "echo world", - ]) + (return_code, stdout, _stderr) = local_exec_service.local_exec( + [ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == "win32" else "false", + "echo world", + ] + ) assert return_code != 0 assert stdout.strip() == "hello" @@ -168,11 +178,13 @@ def test_run_script_middle_fail_abort(local_exec_service: LocalExecService) -> N def test_run_script_middle_fail_pass(local_exec_service: LocalExecService) -> None: """Try to run a series of commands, one of which fails, but let it pass.""" local_exec_service.abort_on_error = False - (return_code, stdout, _stderr) = local_exec_service.local_exec([ - "echo hello", - "cmd /c 'exit 1'" if sys.platform == 'win32' else "false", - "echo world", - ]) + (return_code, stdout, _stderr) = local_exec_service.local_exec( + [ + "echo hello", + "cmd /c 'exit 1'" if sys.platform == "win32" else "false", + "echo world", + ] + ) assert return_code == 0 assert stdout.splitlines() == [ "hello", @@ -188,13 +200,17 @@ def test_temp_dir_path_expansion() -> None: # the fact. with tempfile.TemporaryDirectory() as temp_dir: global_config = { - "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" + "workdir": temp_dir, # e.g., "." or "/tmp/mlos_bench" } config = { # The temp_dir for the LocalExecService should get expanded via workdir global config. 
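
The call-site rewrites above all follow the same convention: when a call no longer fits on one line, black wraps the arguments inside the parentheses, and when they still do not fit on a single continuation line it puts one argument per line with a trailing comma. A schematic sketch with placeholder names (configure_service and its arguments are hypothetical, not part of mlos_bench):

    def configure_service(config: dict, global_config: dict) -> dict:
        # Placeholder logic: merge the two config dicts, for illustration only.
        return {**global_config, **config}

    # Wrapped form; black only produces this layout when the one-line call
    # would exceed the configured line length.
    service = configure_service(
        {"temp_dir": "$workdir/temp"},
        {"workdir": "/tmp/mlos_bench"},
    )
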
"temp_dir": "$workdir/temp", } - local_exec_service = LocalExecService(config, global_config, parent=ConfigPersistenceService()) + local_exec_service = LocalExecService( + config, global_config, parent=ConfigPersistenceService() + ) # pylint: disable=protected-access assert isinstance(local_exec_service._temp_dir, str) - assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join(temp_dir, "temp", abs_path=True) + assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join( + temp_dir, "temp", abs_path=True + ) diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py index 7e8035e6a0..2bae6d8dbd 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/__init__.py @@ -7,5 +7,5 @@ from .mock_local_exec_service import MockLocalExecService __all__ = [ - 'MockLocalExecService', + "MockLocalExecService", ] diff --git a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py index 9582cc62c8..39934c40e8 100644 --- a/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/local/mock/mock_local_exec_service.py @@ -31,16 +31,21 @@ class MockLocalExecService(TempDirContextService, SupportsLocalExec): """Mock methods for LocalExecService testing.""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.local_exec]) + config, global_config, parent, self.merge_methods(methods, [self.local_exec]) ) - def local_exec(self, script_lines: Iterable[str], - env: Optional[Mapping[str, "TunableValue"]] = None, - cwd: Optional[str] = None) -> Tuple[int, str, str]: + def local_exec( + self, + script_lines: Iterable[str], + env: Optional[Mapping[str, "TunableValue"]] = None, + cwd: Optional[str] = None, + ) -> Tuple[int, str, str]: return (0, "", "") diff --git a/mlos_bench/mlos_bench/tests/services/mock_service.py b/mlos_bench/mlos_bench/tests/services/mock_service.py index e1fe7cbc5a..cebea96912 100644 --- a/mlos_bench/mlos_bench/tests/services/mock_service.py +++ b/mlos_bench/mlos_bench/tests/services/mock_service.py @@ -28,19 +28,24 @@ class MockServiceBase(Service, SupportsSomeMethod): """A base service class for testing.""" def __init__( - self, - config: Optional[dict] = None, - global_config: Optional[dict] = None, - parent: Optional[Service] = None, - methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None) -> None: + self, + config: Optional[dict] = None, + global_config: Optional[dict] = None, + parent: Optional[Service] = None, + methods: Optional[Union[Dict[str, Callable], List[Callable]]] = None, + ) -> None: super().__init__( config, global_config, parent, - self.merge_methods(methods, [ - self.some_method, - self.some_other_method, - ])) + self.merge_methods( + methods, + [ + self.some_method, + self.some_other_method, + ], + ), + ) def some_method(self) -> str: """some_method.""" diff 
--git a/mlos_bench/mlos_bench/tests/services/remote/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/__init__.py index 137ea2e888..b486afdb7c 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/__init__.py @@ -13,7 +13,7 @@ from .mock.mock_vm_service import MockVMService __all__ = [ - 'MockFileShareService', - 'MockRemoteExecService', - 'MockVMService', + "MockFileShareService", + "MockRemoteExecService", + "MockVMService", ] diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index 6d54389264..fa1adc9935 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -16,7 +16,9 @@ @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_download_file( + mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService +) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" @@ -24,8 +26,9 @@ def test_download_file(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fil local_path = f"{local_folder}/{filename}" mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, \ - patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client: + with patch.object(mock_share_client, "get_file_client") as mock_get_file_client, patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client: mock_get_directory_client.return_value = Mock(exists=Mock(return_value=False)) azure_fileshare.download(config, remote_path, local_path) @@ -45,38 +48,41 @@ def make_dir_client_returns(remote_folder: str) -> dict: return { remote_folder: Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock(return_value=[ - {"name": "a_folder", "is_directory": True}, - {"name": "a_file_1.csv", "is_directory": False}, - ]) + list_directories_and_files=Mock( + return_value=[ + {"name": "a_folder", "is_directory": True}, + {"name": "a_file_1.csv", "is_directory": False}, + ] + ), ), f"{remote_folder}/a_folder": Mock( exists=Mock(return_value=True), - list_directories_and_files=Mock(return_value=[ - {"name": "a_file_2.csv", "is_directory": False}, - ]) - ), - f"{remote_folder}/a_file_1.csv": Mock( - exists=Mock(return_value=False) - ), - f"{remote_folder}/a_folder/a_file_2.csv": Mock( - exists=Mock(return_value=False) + list_directories_and_files=Mock( + return_value=[ + {"name": "a_file_2.csv", "is_directory": False}, + ] + ), ), + f"{remote_folder}/a_file_1.csv": Mock(exists=Mock(return_value=False)), + f"{remote_folder}/a_folder/a_file_2.csv": Mock(exists=Mock(return_value=False)), } @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_non_recursive(mock_makedirs: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_download_folder_non_recursive( + mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: 
AzureFileShareService +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ - patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + with patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] @@ -85,47 +91,63 @@ def test_download_folder_non_recursive(mock_makedirs: MagicMock, mock_get_file_client.assert_called_with( f"{remote_folder}/a_file_1.csv", ) - mock_get_directory_client.assert_has_calls([ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - ], any_order=True) + mock_get_directory_client.assert_has_calls( + [ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + ], + any_order=True, + ) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") -def test_download_folder_recursive(mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_download_folder_recursive( + mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" dir_client_returns = make_dir_client_returns(remote_folder) - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} - with patch.object(mock_share_client, "get_directory_client") as mock_get_directory_client, \ - patch.object(mock_share_client, "get_file_client") as mock_get_file_client: + with patch.object( + mock_share_client, "get_directory_client" + ) as mock_get_directory_client, patch.object( + mock_share_client, "get_file_client" + ) as mock_get_file_client: mock_get_directory_client.side_effect = lambda x: dir_client_returns[x] azure_fileshare.download(config, remote_folder, local_folder, recursive=True) - mock_get_file_client.assert_has_calls([ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) - mock_get_directory_client.assert_has_calls([ - call(remote_folder), - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) + mock_get_file_client.assert_has_calls( + [ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) + mock_get_directory_client.assert_has_calls( + [ + call(remote_folder), + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") -def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService) -> None: +def test_upload_file( + mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: 
AzureFileShareService +) -> None: filename = "test.csv" remote_folder = "a/remote/folder" local_folder = "some/local/folder" remote_path = f"{remote_folder}/{filename}" local_path = f"{local_folder}/{filename}" - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access mock_isdir.return_value = False config: dict = {} @@ -141,6 +163,7 @@ def test_upload_file(mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshar class MyDirEntry: # pylint: disable=too-few-public-methods """Dummy class for os.DirEntry.""" + def __init__(self, name: str, is_a_dir: bool): self.name = name self.is_a_dir = is_a_dir @@ -184,17 +207,19 @@ def process_paths(input_path: str) -> str: @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_non_recursive(mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_upload_directory_non_recursive( + mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: @@ -206,23 +231,28 @@ def test_upload_directory_non_recursive(mock_scandir: MagicMock, @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.scandir") -def test_upload_directory_recursive(mock_scandir: MagicMock, - mock_isdir: MagicMock, - mock_open: MagicMock, - azure_fileshare: AzureFileShareService) -> None: +def test_upload_directory_recursive( + mock_scandir: MagicMock, + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, +) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" scandir_returns = make_scandir_returns(local_folder) isdir_returns = make_isdir_returns(local_folder) mock_scandir.side_effect = lambda x: scandir_returns[process_paths(x)] mock_isdir.side_effect = lambda x: isdir_returns[process_paths(x)] - mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access + mock_share_client = azure_fileshare._share_client # pylint: disable=protected-access config: dict = {} with patch.object(mock_share_client, "get_file_client") as mock_get_file_client: azure_fileshare.upload(config, local_folder, remote_folder, recursive=True) - mock_get_file_client.assert_has_calls([ - call(f"{remote_folder}/a_file_1.csv"), - call(f"{remote_folder}/a_folder/a_file_2.csv"), - ], any_order=True) + mock_get_file_client.assert_has_calls( + [ + call(f"{remote_folder}/a_file_1.csv"), + call(f"{remote_folder}/a_folder/a_file_2.csv"), + ], + any_order=True, + ) diff --git 
a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py index 67fc9d56fb..87dd78fd5a 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_network_services_test.py @@ -16,23 +16,31 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_network_deployment_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_network_service: AzureNetworkService) -> None: +def test_wait_network_deployment_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_network_service: AzureNetworkService, +) -> None: """Test retries of the network deployment operation.""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] @@ -45,30 +53,37 @@ def test_wait_network_deployment_retry(mock_getconn: MagicMock, "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True) + is_setup=True, + ) assert status == operation_status @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("deprovision_network", True), - ]) + ], +) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), # These should succeed since we set ignore_errors=True by default (401, Status.SUCCEEDED), (404, Status.SUCCEEDED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_network_operation_status(mock_requests: MagicMock, - azure_network_service: AzureNetworkService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status) -> None: +def test_network_operation_status( + mock_requests: MagicMock, + azure_network_service: AzureNetworkService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status, +) -> None: """Test network operation status.""" mock_response = MagicMock() mock_response.status_code = http_status_code @@ -83,20 +98,28 @@ def test_network_operation_status(mock_requests: MagicMock, @pytest.fixture -def test_azure_network_service_no_deployment_template(azure_auth_service: AzureAuthService) -> None: +def test_azure_network_service_no_deployment_template( + 
azure_auth_service: AzureAuthService, +) -> None: """Tests creating a network services without a deployment template (should fail).""" with pytest.raises(ValueError): - _ = AzureNetworkService(config={ - "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", + _ = AzureNetworkService( + config={ + "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", + }, }, - }, parent=azure_auth_service) + parent=azure_auth_service, + ) with pytest.raises(ValueError): - _ = AzureNetworkService(config={ - # "deploymentTemplatePath": None, - "deploymentTemplateParameters": { - "location": "westus2", + _ = AzureNetworkService( + config={ + # "deploymentTemplatePath": None, + "deploymentTemplateParameters": { + "location": "westus2", + }, }, - }, parent=azure_auth_service) + parent=azure_auth_service, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 6b1235f3f7..33f25f48c8 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -17,23 +17,31 @@ @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_host_deployment_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService) -> None: +def test_wait_host_deployment_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService, +) -> None: """Test retries of the host deployment operation.""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), ] @@ -46,14 +54,17 @@ def test_wait_host_deployment_retry(mock_getconn: MagicMock, "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", }, - is_setup=True) + is_setup=True, + ) assert status == operation_status def test_azure_vm_service_recursive_template_params(azure_auth_service: AzureAuthService) -> None: """Test expanding template params recursively.""" config = { - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "deploymentTemplatePath": ( + "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc" + ), "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", "deploymentTemplateParameters": { @@ -69,15 +80,23 @@ def test_azure_vm_service_recursive_template_params(azure_auth_service: 
AzureAut } azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) assert azure_vm_service.deploy_params["location"] == global_config["location"] - assert azure_vm_service.deploy_params["vmMeta"] == f'{global_config["vmName"]}-{global_config["location"]}' - assert azure_vm_service.deploy_params["vmNsg"] == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' + assert ( + azure_vm_service.deploy_params["vmMeta"] + == f'{global_config["vmName"]}-{global_config["location"]}' + ) + assert ( + azure_vm_service.deploy_params["vmNsg"] + == f'{azure_vm_service.deploy_params["vmMeta"]}-nsg' + ) def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> None: """Test loading custom data from a file.""" config = { "customDataFile": "services/remote/azure/cloud-init/alt-ssh.yml", - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", + "deploymentTemplatePath": ( + "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc" + ), "subscription": "TEST_SUB1", "resourceGroup": "TEST_RG1", "deploymentTemplateParameters": { @@ -90,14 +109,15 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N } with pytest.raises(ValueError): config_with_custom_data = deepcopy(config) - config_with_custom_data['deploymentTemplateParameters']['customData'] = "DUMMY_CUSTOM_DATA" # type: ignore[index] + config_with_custom_data["deploymentTemplateParameters"]["customData"] = "DUMMY_CUSTOM_DATA" # type: ignore[index] # pylint: disable=line-too-long # noqa AzureVMService(config_with_custom_data, global_config, parent=azure_auth_service) azure_vm_service = AzureVMService(config, global_config, parent=azure_auth_service) - assert azure_vm_service.deploy_params['customData'] + assert azure_vm_service.deploy_params["customData"] @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("start_host", True), ("stop_host", True), ("shutdown", True), @@ -105,22 +125,27 @@ def test_azure_vm_service_custom_data(azure_auth_service: AzureAuthService) -> N ("deallocate_host", True), ("restart_host", True), ("reboot", True), - ]) + ], +) @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests") # pylint: disable=too-many-arguments -def test_vm_operation_status(mock_requests: MagicMock, - azure_vm_service: AzureVMService, - operation_name: str, - accepts_params: bool, - http_status_code: int, - operation_status: Status) -> None: +def test_vm_operation_status( + mock_requests: MagicMock, + azure_vm_service: AzureVMService, + operation_name: str, + accepts_params: bool, + http_status_code: int, + operation_status: Status, +) -> None: """Test VM operation status.""" mock_response = MagicMock() mock_response.status_code = http_status_code @@ -135,12 +160,14 @@ def test_vm_operation_status(mock_requests: MagicMock, @pytest.mark.parametrize( - ("operation_name", "accepts_params"), [ + ("operation_name", "accepts_params"), + [ ("provision_host", True), - ]) -def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, - operation_name: str, - accepts_params: bool) -> None: + ], +) +def test_vm_operation_invalid( + azure_vm_service_remote_exec_only: AzureVMService, operation_name: str, accepts_params: bool +) -> 
None: """Test VM operation status for an incomplete service config.""" operation = getattr(azure_vm_service_remote_exec_only, operation_name) with pytest.raises(ValueError): @@ -149,8 +176,9 @@ def test_vm_operation_invalid(azure_vm_service_remote_exec_only: AzureVMService, @patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep") @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_ready( + mock_session: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService +) -> None: """Test waiting for the completion of the remote VM operation.""" # Mock response header async_url = "DUMMY_ASYNC_URL" @@ -169,21 +197,18 @@ def test_wait_vm_operation_ready(mock_session: MagicMock, mock_sleep: MagicMock, status, _ = azure_vm_service.wait_host_operation(params) - assert (async_url, ) == mock_session.return_value.get.call_args[0] - assert (retry_after, ) == mock_sleep.call_args[0] + assert (async_url,) == mock_session.return_value.get.call_args[0] + assert (retry_after,) == mock_sleep.call_args[0] assert status.is_succeeded() @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") -def test_wait_vm_operation_timeout(mock_session: MagicMock, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_timeout( + mock_session: MagicMock, azure_vm_service: AzureVMService +) -> None: """Test the time out of the remote VM operation.""" # Mock response header - params = { - "asyncResultsUrl": "DUMMY_ASYNC_URL", - "vmName": "test-vm", - "pollInterval": 1 - } + params = {"asyncResultsUrl": "DUMMY_ASYNC_URL", "vmName": "test-vm", "pollInterval": 1} mock_status_response = MagicMock(status_code=200) mock_status_response.json.return_value = { @@ -196,23 +221,31 @@ def test_wait_vm_operation_timeout(mock_session: MagicMock, @pytest.mark.parametrize( - ("total_retries", "operation_status"), [ + ("total_retries", "operation_status"), + [ (2, Status.SUCCEEDED), (1, Status.FAILED), (0, Status.FAILED), - ]) + ], +) @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") -def test_wait_vm_operation_retry(mock_getconn: MagicMock, - total_retries: int, - operation_status: Status, - azure_vm_service: AzureVMService) -> None: +def test_wait_vm_operation_retry( + mock_getconn: MagicMock, + total_retries: int, + operation_status: Status, + azure_vm_service: AzureVMService, +) -> None: """Test the retries of the remote VM operation.""" # Simulate intermittent connection issues with multiple connection errors # Sufficient retry attempts should result in success, otherwise a graceful failure state mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"status": "InProgress"}), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), - requests_ex.ConnectionError("Connection aborted", OSError(107, "Transport endpoint is not connected")), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), + requests_ex.ConnectionError( + "Connection aborted", OSError(107, "Transport endpoint is not connected") + ), make_httplib_json_response(200, {"status": "InProgress"}), make_httplib_json_response(200, {"status": "Succeeded"}), ] @@ -223,38 +256,50 @@ def test_wait_vm_operation_retry(mock_getconn: MagicMock, "requestTotalRetries": total_retries, 
"asyncResultsUrl": "https://DUMMY_ASYNC_URL", "vmName": "test-vm", - }) + } + ) assert status == operation_status @pytest.mark.parametrize( - ("http_status_code", "operation_status"), [ + ("http_status_code", "operation_status"), + [ (200, Status.SUCCEEDED), (202, Status.PENDING), (401, Status.FAILED), (404, Status.FAILED), - ]) + ], +) @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_status(mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService, - http_status_code: int, operation_status: Status) -> None: +def test_remote_exec_status( + mock_requests: MagicMock, + azure_vm_service_remote_exec_only: AzureVMService, + http_status_code: int, + operation_status: Status, +) -> None: """Test waiting for completion of the remote execution on Azure.""" script = ["command_1", "command_2"] mock_response = MagicMock() mock_response.status_code = http_status_code - mock_response.json = MagicMock(return_value={ - "fake response": "body as json to dict", - }) + mock_response.json = MagicMock( + return_value={ + "fake response": "body as json to dict", + } + ) mock_requests.post.return_value = mock_response - status, _ = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={}) + status, _ = azure_vm_service_remote_exec_only.remote_exec( + script, config={"vmName": "test-vm"}, env_params={} + ) assert status == operation_status @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") -def test_remote_exec_headers_output(mock_requests: MagicMock, - azure_vm_service_remote_exec_only: AzureVMService) -> None: +def test_remote_exec_headers_output( + mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService +) -> None: """Check if HTTP headers from the remote execution on Azure are correct.""" async_url_key = "asyncResultsUrl" async_url_value = "DUMMY_ASYNC_URL" @@ -262,18 +307,22 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, mock_response = MagicMock() mock_response.status_code = 202 - mock_response.headers = { - "Azure-AsyncOperation": async_url_value - } - mock_response.json = MagicMock(return_value={ - "fake response": "body as json to dict", - }) + mock_response.headers = {"Azure-AsyncOperation": async_url_value} + mock_response.json = MagicMock( + return_value={ + "fake response": "body as json to dict", + } + ) mock_requests.post.return_value = mock_response - _, cmd_output = azure_vm_service_remote_exec_only.remote_exec(script, config={"vmName": "test-vm"}, env_params={ - "param_1": 123, - "param_2": "abc", - }) + _, cmd_output = azure_vm_service_remote_exec_only.remote_exec( + script, + config={"vmName": "test-vm"}, + env_params={ + "param_1": 123, + "param_2": "abc", + }, + ) assert async_url_key in cmd_output assert cmd_output[async_url_key] == async_url_value @@ -281,15 +330,13 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, assert mock_requests.post.call_args[1]["json"] == { "commandId": "RunShellScript", "script": script, - "parameters": [ - {"name": "param_1", "value": 123}, - {"name": "param_2", "value": "abc"} - ] + "parameters": [{"name": "param_1", "value": 123}, {"name": "param_2", "value": "abc"}], } @pytest.mark.parametrize( - ("operation_status", "wait_output", "results_output"), [ + ("operation_status", "wait_output", "results_output"), + [ ( Status.SUCCEEDED, { @@ -301,13 +348,18 @@ def test_remote_exec_headers_output(mock_requests: MagicMock, } } }, - {"stdout": "DUMMY_STDOUT_STDERR"} + {"stdout": 
"DUMMY_STDOUT_STDERR"}, ), (Status.PENDING, {}, {}), (Status.FAILED, {}, {}), - ]) -def test_get_remote_exec_results(azure_vm_service_remote_exec_only: AzureVMService, operation_status: Status, - wait_output: dict, results_output: dict) -> None: + ], +) +def test_get_remote_exec_results( + azure_vm_service_remote_exec_only: AzureVMService, + operation_status: Status, + wait_output: dict, + results_output: dict, +) -> None: """Test getting the results of the remote execution on Azure.""" params = {"asyncResultsUrl": "DUMMY_ASYNC_URL"} diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 14a18f94ef..96cdc9f1d1 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -26,8 +26,9 @@ def config_persistence_service() -> ConfigPersistenceService: @pytest.fixture -def azure_auth_service(config_persistence_service: ConfigPersistenceService, - monkeypatch: pytest.MonkeyPatch) -> AzureAuthService: +def azure_auth_service( + config_persistence_service: ConfigPersistenceService, monkeypatch: pytest.MonkeyPatch +) -> AzureAuthService: """Creates a dummy AzureAuthService for tests that require it.""" auth = AzureAuthService(config={}, global_config={}, parent=config_persistence_service) monkeypatch.setattr(auth, "get_access_token", lambda: "TEST_TOKEN") @@ -37,58 +38,78 @@ def azure_auth_service(config_persistence_service: ConfigPersistenceService, @pytest.fixture def azure_network_service(azure_auth_service: AzureAuthService) -> AzureNetworkService: """Creates a dummy Azure VM service for tests that require it.""" - return AzureNetworkService(config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", + return AzureNetworkService( + config={ + "deploymentTemplatePath": ( + "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc" + ), + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", + }, + "pollInterval": 1, + "pollTimeout": 2, }, - "pollInterval": 1, - "pollTimeout": 2 - }, global_config={ - "deploymentName": "TEST_DEPLOYMENT-VNET", - "vnetName": "test-vnet", # Should come from the upper-level config - }, parent=azure_auth_service) + global_config={ + "deploymentName": "TEST_DEPLOYMENT-VNET", + "vnetName": "test-vnet", # Should come from the upper-level config + }, + parent=azure_auth_service, + ) @pytest.fixture def azure_vm_service(azure_auth_service: AzureAuthService) -> AzureVMService: """Creates a dummy Azure VM service for tests that require it.""" - return AzureVMService(config={ - "deploymentTemplatePath": "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc", - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "deploymentTemplateParameters": { - "location": "westus2", + return AzureVMService( + config={ + "deploymentTemplatePath": ( + "services/remote/azure/arm-templates/azuredeploy-ubuntu-vm.jsonc" + ), + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "deploymentTemplateParameters": { + "location": "westus2", + }, + "pollInterval": 1, + "pollTimeout": 2, + }, + global_config={ + "deploymentName": "TEST_DEPLOYMENT-VM", + "vmName": "test-vm", # Should come from the upper-level config }, - "pollInterval": 1, - "pollTimeout": 2 - }, 
global_config={ - "deploymentName": "TEST_DEPLOYMENT-VM", - "vmName": "test-vm", # Should come from the upper-level config - }, parent=azure_auth_service) + parent=azure_auth_service, + ) @pytest.fixture def azure_vm_service_remote_exec_only(azure_auth_service: AzureAuthService) -> AzureVMService: """Creates a dummy Azure VM service with no deployment template.""" - return AzureVMService(config={ - "subscription": "TEST_SUB", - "resourceGroup": "TEST_RG", - "pollInterval": 1, - "pollTimeout": 2, - }, global_config={ - "vmName": "test-vm", # Should come from the upper-level config - }, parent=azure_auth_service) + return AzureVMService( + config={ + "subscription": "TEST_SUB", + "resourceGroup": "TEST_RG", + "pollInterval": 1, + "pollTimeout": 2, + }, + global_config={ + "vmName": "test-vm", # Should come from the upper-level config + }, + parent=azure_auth_service, + ) @pytest.fixture def azure_fileshare(config_persistence_service: ConfigPersistenceService) -> AzureFileShareService: """Creates a dummy AzureFileShareService for tests that require it.""" with patch("mlos_bench.services.remote.azure.azure_fileshare.ShareClient"): - return AzureFileShareService(config={ - "storageAccountName": "TEST_ACCOUNT_NAME", - "storageFileShareName": "TEST_FS_NAME", - "storageAccountKey": "TEST_ACCOUNT_KEY" - }, global_config={}, parent=config_persistence_service) + return AzureFileShareService( + config={ + "storageAccountName": "TEST_ACCOUNT_NAME", + "storageFileShareName": "TEST_FS_NAME", + "storageAccountKey": "TEST_ACCOUNT_KEY", + }, + global_config={}, + parent=config_persistence_service, + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py index 9f75d79eac..482f9ee2a9 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_auth_service.py @@ -16,16 +16,24 @@ class MockAuthService(Service, SupportsAuth): """A collection Service functions for mocking authentication ops.""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [ - self.get_access_token, - self.get_auth_headers, - ]) + config, + global_config, + parent, + self.merge_methods( + methods, + [ + self.get_access_token, + self.get_auth_headers, + ], + ), ) def get_access_token(self) -> str: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index 5378e12837..abeb35f091 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -17,21 +17,30 @@ class MockFileShareService(FileShareService, SupportsFileShareOps): """A collection Service functions for mocking file share ops.""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], 
List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): super().__init__( - config, global_config, parent, - self.merge_methods(methods, [self.upload, self.download]) + config, + global_config, + parent, + self.merge_methods(methods, [self.upload, self.download]), ) self._upload: List[Tuple[str, str]] = [] self._download: List[Tuple[str, str]] = [] - def upload(self, params: dict, local_path: str, remote_path: str, recursive: bool = True) -> None: + def upload( + self, params: dict, local_path: str, remote_path: str, recursive: bool = True + ) -> None: self._upload.append((local_path, remote_path)) - def download(self, params: dict, remote_path: str, local_path: str, recursive: bool = True) -> None: + def download( + self, params: dict, remote_path: str, local_path: str, recursive: bool = True + ) -> None: self._download.append((remote_path, local_path)) def get_upload(self) -> List[Tuple[str, str]]: diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py index 03a02ba14e..a483432023 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_network_service.py @@ -16,10 +16,13 @@ class MockNetworkService(Service, SupportsNetworkProvisioning): """Mock Network service for testing.""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock network services proxy. @@ -34,13 +37,19 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - name: mock_operation for name in ( - # SupportsNetworkProvisioning: - "provision_network", - "deprovision_network", - "wait_network_deployment", - ) - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + name: mock_operation + for name in ( + # SupportsNetworkProvisioning: + "provision_network", + "deprovision_network", + "wait_network_deployment", + ) + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py index f1e29e5cd4..57f90ccd4d 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_remote_exec_service.py @@ -14,10 +14,13 @@ class MockRemoteExecService(Service, SupportsRemoteExec): """Mock remote script execution service.""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock remote exec service. @@ -32,9 +35,14 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. """ super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - "remote_exec": mock_operation, - "get_remote_exec_results": mock_operation, - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + "remote_exec": mock_operation, + "get_remote_exec_results": mock_operation, + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py index 1fe659a23f..0d093df48f 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_vm_service.py @@ -16,10 +16,13 @@ class MockVMService(Service, SupportsHostProvisioning, SupportsHostOps, SupportsOSOps): """Mock VM service for testing.""" - def __init__(self, config: Optional[Dict[str, Any]] = None, - global_config: Optional[Dict[str, Any]] = None, - parent: Optional[Service] = None, - methods: Union[Dict[str, Callable], List[Callable], None] = None): + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + global_config: Optional[Dict[str, Any]] = None, + parent: Optional[Service] = None, + methods: Union[Dict[str, Callable], List[Callable], None] = None, + ): """ Create a new instance of mock VM services proxy. @@ -34,23 +37,29 @@ def __init__(self, config: Optional[Dict[str, Any]] = None, Parent service that can provide mixin functions. 
""" super().__init__( - config, global_config, parent, - self.merge_methods(methods, { - name: mock_operation for name in ( - # SupportsHostProvisioning: - "wait_host_deployment", - "provision_host", - "deprovision_host", - "deallocate_host", - # SupportsHostOps: - "start_host", - "stop_host", - "restart_host", - "wait_host_operation", - # SupportsOsOps: - "shutdown", - "reboot", - "wait_os_operation", - ) - }) + config, + global_config, + parent, + self.merge_methods( + methods, + { + name: mock_operation + for name in ( + # SupportsHostProvisioning: + "wait_host_deployment", + "provision_host", + "deprovision_host", + "deallocate_host", + # SupportsHostOps: + "start_host", + "stop_host", + "restart_host", + "wait_host_operation", + # SupportsOsOps: + "shutdown", + "reboot", + "wait_os_operation", + ) + }, + ), ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py index 9d5e0ef153..78bd4b1bab 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/__init__.py @@ -15,9 +15,9 @@ # The SSH test server port and name. # See Also: docker-compose.yml SSH_TEST_SERVER_PORT = 2254 -SSH_TEST_SERVER_NAME = 'ssh-server' -ALT_TEST_SERVER_NAME = 'alt-server' -REBOOT_TEST_SERVER_NAME = 'reboot-server' +SSH_TEST_SERVER_NAME = "ssh-server" +ALT_TEST_SERVER_NAME = "alt-server" +REBOOT_TEST_SERVER_NAME = "reboot-server" @dataclass @@ -35,11 +35,19 @@ def get_port(self, uncached: bool = False) -> int: """ Gets the port that the SSH test server is listening on. - Note: this value can change when the service restarts so we can't rely on the DockerServices. + Note: this value can change when the service restarts so we can't rely on + the DockerServices. """ if self._port is None or uncached: - port_cmd = run(f"docker compose -p {self.compose_project_name} port {self.service_name} {SSH_TEST_SERVER_PORT}", - shell=True, check=True, capture_output=True) + port_cmd = run( + ( + f"docker compose -p {self.compose_project_name} " + f"port {self.service_name} {SSH_TEST_SERVER_PORT}" + ), + shell=True, + check=True, + capture_output=True, + ) self._port = int(port_cmd.stdout.decode().strip().split(":")[1]) return self._port diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 28c0367afa..913b045a76 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -30,26 +30,28 @@ # pylint: disable=redefined-outer-name -HOST_DOCKER_NAME = 'host.docker.internal' +HOST_DOCKER_NAME = "host.docker.internal" @pytest.fixture(scope="session") def ssh_test_server_hostname() -> str: """Returns the local hostname to use to connect to the test ssh server.""" - if sys.platform != 'win32' and resolve_host_name(HOST_DOCKER_NAME): + if sys.platform != "win32" and resolve_host_name(HOST_DOCKER_NAME): # On Linux, if we're running in a docker container, we can use the # --add-host (extra_hosts in docker-compose.yml) to refer to the host IP. return HOST_DOCKER_NAME # Docker (Desktop) for Windows (WSL2) uses a special networking magic # to refer to the host machine as `localhost` when exposing ports. # In all other cases, assume we're executing directly inside conda on the host. 
- return 'localhost' + return "localhost" @pytest.fixture(scope="session") -def ssh_test_server(ssh_test_server_hostname: str, - docker_compose_project_name: str, - locked_docker_services: DockerServices) -> Generator[SshTestServerInfo, None, None]: +def ssh_test_server( + ssh_test_server_hostname: str, + docker_compose_project_name: str, + locked_docker_services: DockerServices, +) -> Generator[SshTestServerInfo, None, None]: """ Fixture for getting the ssh test server services setup via docker-compose using pytest-docker. @@ -65,23 +67,38 @@ def ssh_test_server(ssh_test_server_hostname: str, compose_project_name=docker_compose_project_name, service_name=SSH_TEST_SERVER_NAME, hostname=ssh_test_server_hostname, - username='root', - id_rsa_path=id_rsa_file.name) - wait_docker_service_socket(locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port()) + username="root", + id_rsa_path=id_rsa_file.name, + ) + wait_docker_service_socket( + locked_docker_services, ssh_test_server_info.hostname, ssh_test_server_info.get_port() + ) id_rsa_src = f"/{ssh_test_server_info.username}/.ssh/id_rsa" - docker_cp_cmd = f"docker compose -p {docker_compose_project_name} cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}" - cmd = run(docker_cp_cmd.split(), check=True, cwd=os.path.dirname(__file__), capture_output=True, text=True) + docker_cp_cmd = ( + f"docker compose -p {docker_compose_project_name} " + f"cp {SSH_TEST_SERVER_NAME}:{id_rsa_src} {id_rsa_file.name}" + ) + cmd = run( + docker_cp_cmd.split(), + check=True, + cwd=os.path.dirname(__file__), + capture_output=True, + text=True, + ) if cmd.returncode != 0: - raise RuntimeError(f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " - + f"[return={cmd.returncode}]: {str(cmd.stderr)}") + raise RuntimeError( + f"Failed to copy ssh key from {SSH_TEST_SERVER_NAME} container " + + f"[return={cmd.returncode}]: {str(cmd.stderr)}" + ) os.chmod(id_rsa_file.name, 0o600) yield ssh_test_server_info # NamedTempFile deleted on context exit @pytest.fixture(scope="session") -def alt_test_server(ssh_test_server: SshTestServerInfo, - locked_docker_services: DockerServices) -> SshTestServerInfo: +def alt_test_server( + ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices +) -> SshTestServerInfo: """ Fixture for getting the second ssh test server info from the docker-compose.yml. @@ -95,14 +112,18 @@ def alt_test_server(ssh_test_server: SshTestServerInfo, service_name=ALT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path) - wait_docker_service_socket(locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port()) + id_rsa_path=ssh_test_server.id_rsa_path, + ) + wait_docker_service_socket( + locked_docker_services, alt_test_server_info.hostname, alt_test_server_info.get_port() + ) return alt_test_server_info @pytest.fixture(scope="session") -def reboot_test_server(ssh_test_server: SshTestServerInfo, - locked_docker_services: DockerServices) -> SshTestServerInfo: +def reboot_test_server( + ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices +) -> SshTestServerInfo: """ Fixture for getting the third ssh test server info from the docker-compose.yml. 
@@ -116,8 +137,13 @@ def reboot_test_server(ssh_test_server: SshTestServerInfo, service_name=REBOOT_TEST_SERVER_NAME, hostname=ssh_test_server.hostname, username=ssh_test_server.username, - id_rsa_path=ssh_test_server.id_rsa_path) - wait_docker_service_socket(locked_docker_services, reboot_test_server_info.hostname, reboot_test_server_info.get_port()) + id_rsa_path=ssh_test_server.id_rsa_path, + ) + wait_docker_service_socket( + locked_docker_services, + reboot_test_server_info.hostname, + reboot_test_server_info.get_port(), + ) return reboot_test_server_info diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index 7b1a4e0756..a6a7c6149b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -50,8 +50,9 @@ def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, @requires_docker -def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_single_file( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService single file download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -64,7 +65,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, lines = [line + "\n" for line in lines] # 1. Write a local file and upload it. - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines(lines) temp_file.flush() temp_file.close() @@ -76,7 +77,7 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, ) # 2. Download the remote file and compare the contents. - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.close() ssh_fileshare_service.download( params=config, @@ -84,14 +85,15 @@ def test_ssh_fileshare_single_file(ssh_test_server: SshTestServerInfo, local_path=temp_file.name, ) # Download will replace the inode at that name, so we need to reopen the file. - with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: + with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == lines @requires_docker -def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_recursive( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService recursive download/upload.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() @@ -111,14 +113,16 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, "bar", ], } - files_lines = {path: [line + "\n" for line in lines] for (path, lines) in files_lines.items()} + files_lines = { + path: [line + "\n" for line in lines] for (path, lines) in files_lines.items() + } with tempfile.TemporaryDirectory() as tempdir1, tempfile.TemporaryDirectory() as tempdir2: # Setup the directory structure. 
- for (file_path, lines) in files_lines.items(): + for file_path, lines in files_lines.items(): path = Path(tempdir1, file_path) path.parent.mkdir(parents=True, exist_ok=True) - with open(path, mode='w+t', encoding='utf-8') as temp_file: + with open(path, mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines(lines) temp_file.flush() assert os.path.getsize(path) > 0 @@ -145,15 +149,16 @@ def test_ssh_fileshare_recursive(ssh_test_server: SshTestServerInfo, @requires_docker -def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_download_file_dne( + ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService +) -> None: """Test the SshFileShareService single file download that doesn't exist.""" with ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() canary_str = "canary" - with closeable_temp_file(mode='w+t', encoding='utf-8') as temp_file: + with closeable_temp_file(mode="w+t", encoding="utf-8") as temp_file: temp_file.writelines([canary_str]) temp_file.flush() temp_file.close() @@ -164,20 +169,22 @@ def test_ssh_fileshare_download_file_dne(ssh_test_server: SshTestServerInfo, remote_path="/tmp/file-dne.txt", local_path=temp_file.name, ) - with open(temp_file.name, mode='r', encoding='utf-8') as temp_file_h: + with open(temp_file.name, mode="r", encoding="utf-8") as temp_file_h: read_lines = temp_file_h.readlines() assert read_lines == [canary_str] @requires_docker -def test_ssh_fileshare_upload_file_dne(ssh_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - ssh_fileshare_service: SshFileShareService) -> None: +def test_ssh_fileshare_upload_file_dne( + ssh_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + ssh_fileshare_service: SshFileShareService, +) -> None: """Test the SshFileShareService single file upload that doesn't exist.""" with ssh_host_service, ssh_fileshare_service: config = ssh_test_server.to_ssh_service_config() - path = '/tmp/upload-file-src-dne.txt' + path = "/tmp/upload-file-src-dne.txt" with pytest.raises(OSError): ssh_fileshare_service.upload( params=config, diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index ab7f3cd9e0..54ceb9984e 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -25,9 +25,11 @@ @requires_docker -def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, - alt_test_server: SshTestServerInfo, - ssh_host_service: SshHostService) -> None: +def test_ssh_service_remote_exec( + ssh_test_server: SshTestServerInfo, + alt_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, +) -> None: """ Test the SshHostService remote_exec. 
@@ -40,7 +42,9 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, connection_id = SshClient.id_from_params(ssh_test_server.to_connect_params()) assert ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None - connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get(connection_id) + connection_client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache.get( + connection_id + ) assert connection_client is None (status, results_info) = ssh_host_service.remote_exec( @@ -55,7 +59,9 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, assert results["stdout"].strip() == SSH_TEST_SERVER_NAME # Check that the client caching is behaving as expected. - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ + connection_id + ] assert connection is not None assert connection._username == ssh_test_server.username assert connection._host == ssh_test_server.hostname @@ -70,7 +76,8 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, script=["hostname"], config=alt_test_server.to_ssh_service_config(), env_params={ - "UNUSED": "unused", # unused, making sure it doesn't carry over with cached connections + # unused, making sure it doesn't carry over with cached connections + "UNUSED": "unused", }, ) assert status.is_pending() @@ -89,13 +96,15 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) - assert status.is_failed() # should retain exit code from "false" + assert status.is_failed() # should retain exit code from "false" stdout = str(results["stdout"]) assert stdout.splitlines() == [ "BAR=bar", "UNUSED=", ] - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ + connection_id + ] assert connection._local_port == local_port # Close the connection (gracefully) @@ -112,7 +121,7 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, config=config, # Also test interacting with environment_variables. env_params={ - 'FOO': 'foo', + "FOO": "foo", }, ) status, results = ssh_host_service.get_remote_exec_results(results_info) @@ -125,17 +134,21 @@ def test_ssh_service_remote_exec(ssh_test_server: SshTestServerInfo, "BAZ=", ] # Make sure it looks like we reconnected. - connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[connection_id] + connection, client = ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE._cache[ + connection_id + ] assert connection._local_port != local_port # Make sure the cache is cleaned up on context exit. assert len(SshHostService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE) == 0 -def check_ssh_service_reboot(docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService, - graceful: bool) -> None: +def check_ssh_service_reboot( + docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, + graceful: bool, +) -> None: """Check the SshHostService reboot operation.""" # Note: rebooting changes the port number unfortunately, but makes it # easier to check for success. 
@@ -144,11 +157,7 @@ def check_ssh_service_reboot(docker_services: DockerServices, with ssh_host_service: reboot_test_srv_ssh_svc_conf = reboot_test_server.to_ssh_service_config(uncached=True) (status, results_info) = ssh_host_service.remote_exec( - script=[ - 'echo "sleeping..."', - 'sleep 30', - 'echo "should not reach this point"' - ], + script=['echo "sleeping..."', "sleep 30", 'echo "should not reach this point"'], config=reboot_test_srv_ssh_svc_conf, env_params={}, ) @@ -157,8 +166,9 @@ def check_ssh_service_reboot(docker_services: DockerServices, time.sleep(1) # Now try to restart the server. - (status, reboot_results_info) = ssh_host_service.reboot(params=reboot_test_srv_ssh_svc_conf, - force=not graceful) + (status, reboot_results_info) = ssh_host_service.reboot( + params=reboot_test_srv_ssh_svc_conf, force=not graceful + ) assert status.is_pending() (status, reboot_results_info) = ssh_host_service.wait_os_operation(reboot_results_info) @@ -179,19 +189,34 @@ def check_ssh_service_reboot(docker_services: DockerServices, time.sleep(1) # try to reconnect and see if the port changed try: - run_res = run("docker ps | grep mlos_bench-test- | grep reboot", shell=True, capture_output=True, check=False) + run_res = run( + "docker ps | grep mlos_bench-test- | grep reboot", + shell=True, + capture_output=True, + check=False, + ) print(run_res.stdout.decode()) print(run_res.stderr.decode()) - reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config(uncached=True) - if reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"]: + reboot_test_srv_ssh_svc_conf_new = reboot_test_server.to_ssh_service_config( + uncached=True + ) + if ( + reboot_test_srv_ssh_svc_conf_new["ssh_port"] + != reboot_test_srv_ssh_svc_conf["ssh_port"] + ): break except CalledProcessError as ex: _LOG.info("Failed to check port for reboot test server: %s", ex) - assert reboot_test_srv_ssh_svc_conf_new["ssh_port"] != reboot_test_srv_ssh_svc_conf["ssh_port"] + assert ( + reboot_test_srv_ssh_svc_conf_new["ssh_port"] + != reboot_test_srv_ssh_svc_conf["ssh_port"] + ) - wait_docker_service_socket(docker_services, - reboot_test_server.hostname, - reboot_test_srv_ssh_svc_conf_new["ssh_port"]) + wait_docker_service_socket( + docker_services, + reboot_test_server.hostname, + reboot_test_srv_ssh_svc_conf_new["ssh_port"], + ) (status, results_info) = ssh_host_service.remote_exec( script=["hostname"], @@ -204,10 +229,16 @@ def check_ssh_service_reboot(docker_services: DockerServices, @requires_docker -def test_ssh_service_reboot(locked_docker_services: DockerServices, - reboot_test_server: SshTestServerInfo, - ssh_host_service: SshHostService) -> None: +def test_ssh_service_reboot( + locked_docker_services: DockerServices, + reboot_test_server: SshTestServerInfo, + ssh_host_service: SshHostService, +) -> None: """Test the SshHostService reboot operation.""" # Grouped together to avoid parallel runner interactions. 
- check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=True) - check_ssh_service_reboot(locked_docker_services, reboot_test_server, ssh_host_service, graceful=False) + check_ssh_service_reboot( + locked_docker_services, reboot_test_server, ssh_host_service, graceful=True + ) + check_ssh_service_reboot( + locked_docker_services, reboot_test_server, ssh_host_service, graceful=False + ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py index 1eabd7ea37..5b335477a9 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_service.py @@ -33,7 +33,9 @@ # We replaced pytest-lazy-fixture with pytest-lazy-fixtures: # https://github.com/TvoroG/pytest-lazy-fixture/issues/65 if version("pytest-lazy-fixture"): - raise UserWarning("pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it.") + raise UserWarning( + "pytest-lazy-fixture conflicts with pytest>=8.0.0. Please remove it." + ) except PackageNotFoundError: # OK: pytest-lazy-fixture not installed pass @@ -41,12 +43,14 @@ @requires_docker @requires_ssh -@pytest.mark.parametrize(["ssh_test_server_info", "server_name"], [ - (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), - (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), -]) -def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, - server_name: str) -> None: +@pytest.mark.parametrize( + ["ssh_test_server_info", "server_name"], + [ + (lazy_fixture("ssh_test_server"), SSH_TEST_SERVER_NAME), + (lazy_fixture("alt_test_server"), ALT_TEST_SERVER_NAME), + ], +) +def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, server_name: str) -> None: """Check for the pytest-docker ssh test infra.""" assert ssh_test_server_info.service_name == server_name @@ -55,17 +59,18 @@ def test_ssh_service_test_infra(ssh_test_server_info: SshTestServerInfo, local_port = ssh_test_server_info.get_port() assert check_socket(ip_addr, local_port) - ssh_cmd = "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " \ - + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " \ + ssh_cmd = ( + "ssh -o BatchMode=yes -o StrictHostKeyChecking=accept-new " + + f"-l {ssh_test_server_info.username} -i {ssh_test_server_info.id_rsa_path} " + f"-p {local_port} {ssh_test_server_info.hostname} hostname" - cmd = run(ssh_cmd.split(), - capture_output=True, - text=True, - check=True) + ) + cmd = run(ssh_cmd.split(), capture_output=True, text=True, check=True) assert cmd.stdout.strip() == server_name -@pytest.mark.filterwarnings("ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0") +@pytest.mark.filterwarnings( + "ignore:.*(coroutine 'sleep' was never awaited).*:RuntimeWarning:.*event_loop_context_test.*:0" +) def test_ssh_service_context_handler() -> None: """ Test the SSH service context manager handling. @@ -86,7 +91,9 @@ def test_ssh_service_context_handler() -> None: # After we enter the SshService instance context, we should have a background thread. with ssh_host_service: assert ssh_host_service._in_context - assert isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) # type: ignore[unreachable] + assert ( # type: ignore[unreachable] + isinstance(SshService._EVENT_LOOP_CONTEXT._event_loop_thread, Thread) + ) # Give the thread a chance to start. 
# Mostly important on the underpowered Windows CI machines. time.sleep(0.25) @@ -99,17 +106,23 @@ def test_ssh_service_context_handler() -> None: with ssh_fileshare_service: assert ssh_fileshare_service._in_context assert ssh_host_service._in_context - assert SshService._EVENT_LOOP_CONTEXT._event_loop_thread \ - is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread \ + assert ( + SshService._EVENT_LOOP_CONTEXT._event_loop_thread + is ssh_host_service._EVENT_LOOP_CONTEXT._event_loop_thread is ssh_fileshare_service._EVENT_LOOP_CONTEXT._event_loop_thread - assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ - is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE \ + ) + assert ( + SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + is ssh_host_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is ssh_fileshare_service._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE + ) assert not ssh_fileshare_service._in_context # And that instance should be unusable after we are outside the context. - with pytest.raises(AssertionError): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): - future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result='foo')) + with pytest.raises( + AssertionError + ): # , pytest.warns(RuntimeWarning, match=r".*coroutine 'sleep' was never awaited"): + future = ssh_fileshare_service._run_coroutine(asyncio.sleep(0.1, result="foo")) raise ValueError(f"Future should not have been available to wait on {future.result()}") # The background thread should remain running since we have another context still open. @@ -117,6 +130,6 @@ def test_ssh_service_context_handler() -> None: assert SshService._EVENT_LOOP_THREAD_SSH_CLIENT_CACHE is not None -if __name__ == '__main__': +if __name__ == "__main__": # For debugging in Windows which has issues with pytest detection in vscode. 
pytest.main(["-n1", "--dist=no", "-k", "test_ssh_service_background_thread"]) diff --git a/mlos_bench/mlos_bench/tests/storage/conftest.py b/mlos_bench/mlos_bench/tests/storage/conftest.py index 879be9497a..52b0fdcd53 100644 --- a/mlos_bench/mlos_bench/tests/storage/conftest.py +++ b/mlos_bench/mlos_bench/tests/storage/conftest.py @@ -17,7 +17,9 @@ mixed_numerics_exp_storage = sql_storage_fixtures.mixed_numerics_exp_storage exp_storage_with_trials = sql_storage_fixtures.exp_storage_with_trials exp_no_tunables_storage_with_trials = sql_storage_fixtures.exp_no_tunables_storage_with_trials -mixed_numerics_exp_storage_with_trials = sql_storage_fixtures.mixed_numerics_exp_storage_with_trials +mixed_numerics_exp_storage_with_trials = ( + sql_storage_fixtures.mixed_numerics_exp_storage_with_trials +) exp_data = sql_storage_fixtures.exp_data exp_no_tunables_data = sql_storage_fixtures.exp_no_tunables_data mixed_numerics_exp_data = sql_storage_fixtures.mixed_numerics_exp_data diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index 941683333e..e6ef30db6a 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -18,21 +18,30 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) assert exp.objectives == exp_storage.opt_targets -def test_exp_data_root_env_config(exp_storage: Storage.Experiment, exp_data: ExperimentData) -> None: +def test_exp_data_root_env_config( + exp_storage: Storage.Experiment, exp_data: ExperimentData +) -> None: """Tests the root_env_config property of ExperimentData.""" # pylint: disable=protected-access - assert exp_data.root_env_config == (exp_storage._root_env_config, exp_storage._git_repo, exp_storage._git_commit) + assert exp_data.root_env_config == ( + exp_storage._root_env_config, + exp_storage._git_repo, + exp_storage._git_commit, + ) -def test_exp_trial_data_objectives(storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_data_objectives( + storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups +) -> None: """Start a new trial and check the storage for the trial data.""" - trial_opt_new = exp_storage.new_trial(tunable_groups, config={ - "opt_target": "some-other-target", - "opt_direction": "max", - }) + trial_opt_new = exp_storage.new_trial( + tunable_groups, + config={ + "opt_target": "some-other-target", + "opt_direction": "max", + }, + ) assert trial_opt_new.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_new.trial_id, @@ -40,10 +49,13 @@ def test_exp_trial_data_objectives(storage: Storage, "opt_direction": "max", } - trial_opt_old = exp_storage.new_trial(tunable_groups, config={ - "opt_target": "back-compat", - # "opt_direction": "max", # missing - }) + trial_opt_old = exp_storage.new_trial( + tunable_groups, + config={ + "opt_target": "back-compat", + # "opt_direction": "max", # missing + }, + ) assert trial_opt_old.config() == { "experiment_id": exp_storage.experiment_id, "trial_id": trial_opt_old.trial_id, @@ -68,9 +80,14 @@ def test_exp_data_results_df(exp_data: ExperimentData, tunable_groups: TunableGr assert len(results_df["tunable_config_id"].unique()) == CONFIG_COUNT assert len(results_df["trial_id"].unique()) == expected_trials_count obj_target = next(iter(exp_data.objectives)) - assert len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == 
expected_trials_count + assert ( + len(results_df[ExperimentData.RESULT_COLUMN_PREFIX + obj_target]) == expected_trials_count + ) (tunable, _covariant_group) = next(iter(tunable_groups)) - assert len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) == expected_trials_count + assert ( + len(results_df[ExperimentData.CONFIG_COLUMN_PREFIX + tunable.name]) + == expected_trials_count + ) def test_exp_data_tunable_config_trial_group_id_in_results_df(exp_data: ExperimentData) -> None: @@ -110,13 +127,15 @@ def test_exp_data_tunable_config_trial_groups(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_config_trial_groups.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. - assert [config_trial_group.tunable_config_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [ + config_trial_group.tunable_config_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT + 1)) # And the tunable_config_trial_group_id should also match the minimum trial_id. - assert [config_trial_group.tunable_config_trial_group_id - for config_trial_group in exp_data.tunable_config_trial_groups.values() - ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) + assert [ + config_trial_group.tunable_config_trial_group_id + for config_trial_group in exp_data.tunable_config_trial_groups.values() + ] == list(range(1, CONFIG_COUNT * CONFIG_TRIAL_REPEAT_COUNT, CONFIG_TRIAL_REPEAT_COUNT)) def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: @@ -124,9 +143,9 @@ def test_exp_data_tunable_configs(exp_data: ExperimentData) -> None: # Should be keyed by config_id. assert list(exp_data.tunable_configs.keys()) == list(range(1, CONFIG_COUNT + 1)) # Which should match the objects. 
- assert [config.tunable_config_id - for config in exp_data.tunable_configs.values() - ] == list(range(1, CONFIG_COUNT + 1)) + assert [config.tunable_config_id for config in exp_data.tunable_configs.values()] == list( + range(1, CONFIG_COUNT + 1) + ) def test_exp_data_default_config_id(exp_data: ExperimentData) -> None: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index d69a580b9e..0cbd02ae97 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -31,9 +31,9 @@ def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """Start a trial and check that it is pending.""" trial = exp_storage.new_trial(tunable_groups) (pending,) = list(exp_storage.pending_trials(datetime.now(zone_info), running=True)) @@ -42,12 +42,12 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_many(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_many( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """Start THREE trials and check that both are pending.""" - config1 = tunable_groups.copy().assign({'idle': 'mwait'}) - config2 = tunable_groups.copy().assign({'idle': 'noidle'}) + config1 = tunable_groups.copy().assign({"idle": "mwait"}) + config2 = tunable_groups.copy().assign({"idle": "noidle"}) trial_ids = { exp_storage.new_trial(config1).trial_id, exp_storage.new_trial(config2).trial_id, @@ -62,9 +62,9 @@ def test_exp_trial_pending_many(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_fail( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """Start a trial, fail it, and and check that it is NOT pending.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.FAILED, datetime.now(zone_info)) @@ -73,9 +73,9 @@ def test_exp_trial_pending_fail(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_success(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_success( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """Start a trial, finish it successfully, and and check that it is NOT pending.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9}) @@ -84,29 +84,31 @@ def test_exp_trial_success(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_categ(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_update_categ( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, 
zone_info: Optional[tzinfo] +) -> None: """Update the trial with multiple metrics, some of which are categorical.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.SUCCEEDED, datetime.now(zone_info), {"score": 99.9, "benchmark": "test"}) assert exp_storage.load() == ( [trial.trial_id], - [{ - 'idle': 'halt', - 'kernel_sched_latency_ns': '2000000', - 'kernel_sched_migration_cost_ns': '-1', - 'vmSize': 'Standard_B4ms' - }], + [ + { + "idle": "halt", + "kernel_sched_latency_ns": "2000000", + "kernel_sched_migration_cost_ns": "-1", + "vmSize": "Standard_B4ms", + } + ], [{"score": "99.9", "benchmark": "test"}], - [Status.SUCCEEDED] + [Status.SUCCEEDED], ) @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_update_twice(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_update_twice( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """Update the trial status twice and receive an error.""" trial = exp_storage.new_trial(tunable_groups) trial.update(Status.FAILED, datetime.now(zone_info)) @@ -115,9 +117,9 @@ def test_exp_trial_update_twice(exp_storage: Storage.Experiment, @pytest.mark.parametrize(("zone_info"), ZONE_INFO) -def test_exp_trial_pending_3(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - zone_info: Optional[tzinfo]) -> None: +def test_exp_trial_pending_3( + exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] +) -> None: """ Start THREE trials, let one succeed, another one fail and keep one not updated. diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index 839404ff0b..fa26245e78 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -32,7 +32,7 @@ def storage() -> SqlStorage: "drivername": "sqlite", "database": ":memory:", # "database": "mlos_bench.pytest.db", - } + }, ) @@ -106,7 +106,9 @@ def mixed_numerics_exp_storage( assert not exp._in_context -def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> SqlStorage.Experiment: +def _dummy_run_exp( + exp: SqlStorage.Experiment, tunable_name: Optional[str] +) -> SqlStorage.Experiment: """Generates data by doing a simulated run of the given experiment.""" # Add some trials to that experiment. # Note: we're just fabricating some made up function for the ML libraries to try and learn. @@ -117,24 +119,31 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S (tunable_min, tunable_max) = tunable.range tunable_range = tunable_max - tunable_min rand_seed(SEED) - opt = MockOptimizer(tunables=exp.tunables, config={ - "seed": SEED, - # This should be the default, so we leave it omitted for now to test the default. - # But the test logic relies on this (e.g., trial 1 is config 1 is the default values for the tunable params) - # "start_with_defaults": True, - }) + opt = MockOptimizer( + tunables=exp.tunables, + config={ + "seed": SEED, + # This should be the default, so we leave it omitted for now to test the default. 
+ # But the test logic relies on this (e.g., trial 1 is config 1 is the + # default values for the tunable params) + # "start_with_defaults": True, + }, + ) assert opt.start_with_defaults for config_i in range(CONFIG_COUNT): tunables = opt.suggest() for repeat_j in range(CONFIG_TRIAL_REPEAT_COUNT): - trial = exp.new_trial(tunables=tunables.copy(), config={ - "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, - **{ - f"opt_{key}_{i}": val - for (i, opt_target) in enumerate(exp.opt_targets.items()) - for (key, val) in zip(["target", "direction"], opt_target) - } - }) + trial = exp.new_trial( + tunables=tunables.copy(), + config={ + "trial_number": config_i * CONFIG_TRIAL_REPEAT_COUNT + repeat_j + 1, + **{ + f"opt_{key}_{i}": val + for (i, opt_target) in enumerate(exp.opt_targets.items()) + for (key, val) in zip(["target", "direction"], opt_target) + }, + }, + ) if exp.tunables: assert trial.tunable_config_id == config_i + 1 else: @@ -145,14 +154,23 @@ def _dummy_run_exp(exp: SqlStorage.Experiment, tunable_name: Optional[str]) -> S else: tunable_value_norm = 0 timestamp = datetime.now(UTC) - trial.update_telemetry(status=Status.RUNNING, timestamp=timestamp, metrics=[ - (timestamp, "some-metric", tunable_value_norm + random() / 100), - ]) - trial.update(Status.SUCCEEDED, timestamp, metrics={ - # Give some variance on the score. - # And some influence from the tunable value. - "score": tunable_value_norm + random() / 100 - }) + trial.update_telemetry( + status=Status.RUNNING, + timestamp=timestamp, + metrics=[ + (timestamp, "some-metric", tunable_value_norm + random() / 100), + ], + ) + trial.update( + Status.SUCCEEDED, + timestamp, + metrics={ + # Give some variance on the score. + # And some influence from the tunable value. + "score": tunable_value_norm + + random() / 100 + }, + ) return exp @@ -163,32 +181,42 @@ def exp_storage_with_trials(exp_storage: SqlStorage.Experiment) -> SqlStorage.Ex @pytest.fixture -def exp_no_tunables_storage_with_trials(exp_no_tunables_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: +def exp_no_tunables_storage_with_trials( + exp_no_tunables_storage: SqlStorage.Experiment, +) -> SqlStorage.Experiment: """Test fixture for Experiment using in-memory SQLite3 storage.""" assert not exp_no_tunables_storage.tunables return _dummy_run_exp(exp_no_tunables_storage, tunable_name=None) @pytest.fixture -def mixed_numerics_exp_storage_with_trials(mixed_numerics_exp_storage: SqlStorage.Experiment) -> SqlStorage.Experiment: +def mixed_numerics_exp_storage_with_trials( + mixed_numerics_exp_storage: SqlStorage.Experiment, +) -> SqlStorage.Experiment: """Test fixture for Experiment using in-memory SQLite3 storage.""" tunable = next(iter(mixed_numerics_exp_storage.tunables))[0] return _dummy_run_exp(mixed_numerics_exp_storage, tunable_name=tunable.name) @pytest.fixture -def exp_data(storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def exp_data( + storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """Test fixture for ExperimentData.""" return storage.experiments[exp_storage_with_trials.experiment_id] @pytest.fixture -def exp_no_tunables_data(storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def exp_no_tunables_data( + storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """Test fixture for ExperimentData with no tunable configs.""" return 
storage.experiments[exp_no_tunables_storage_with_trials.experiment_id] @pytest.fixture -def mixed_numerics_exp_data(storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment) -> ExperimentData: +def mixed_numerics_exp_data( + storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment +) -> ExperimentData: """Test fixture for ExperimentData with mixed numerical tunable types.""" return storage.experiments[mixed_numerics_exp_storage_with_trials.experiment_id] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py index 851993f4a2..b5f4778a74 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_config_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_config_test.py @@ -11,8 +11,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_exp_trial_pending(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_pending(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: """Schedule a trial and check that it is pending and has the right configuration.""" config = {"location": "westus2", "num_repeats": 100} trial = exp_storage.new_trial(tunable_groups, config=config) @@ -27,12 +26,11 @@ def test_exp_trial_pending(exp_storage: Storage.Experiment, } -def test_exp_trial_configs(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_exp_trial_configs(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: """Start multiple trials with two different configs and check that we store only two config objects in the DB. """ - config1 = tunable_groups.copy().assign({'idle': 'mwait'}) + config1 = tunable_groups.copy().assign({"idle": "mwait"}) trials1 = [ exp_storage.new_trial(config1), exp_storage.new_trial(config1), @@ -41,7 +39,7 @@ def test_exp_trial_configs(exp_storage: Storage.Experiment, assert trials1[0].tunable_config_id == trials1[1].tunable_config_id assert trials1[0].tunable_config_id == trials1[2].tunable_config_id - config2 = tunable_groups.copy().assign({'idle': 'halt'}) + config2 = tunable_groups.copy().assign({"idle": "halt"}) trials2 = [ exp_storage.new_trial(config2), exp_storage.new_trial(config2), diff --git a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py index 628051a373..0a4d72480d 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_schedule_test.py @@ -18,8 +18,7 @@ def _trial_ids(trials: Iterator[Storage.Trial]) -> Set[int]: return set(t.trial_id for t in trials) -def test_schedule_trial(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups) -> None: +def test_schedule_trial(exp_storage: Storage.Experiment, tunable_groups: TunableGroups) -> None: """Schedule several trials for future execution and retrieve them later at certain timestamps. 
""" @@ -40,16 +39,14 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Scheduler side: get trials ready to run at certain timestamps: # Pretend 1 minute has passed, get trials scheduled to run: - pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) + pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1min, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, } # Get trials scheduled to run within the next 1 hour: - pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) + pending_ids = _trial_ids(exp_storage.pending_trials(timestamp + timedelta_1hr, running=False)) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -58,7 +55,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -80,7 +78,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=False) + ) assert pending_ids == { trial_1h.trial_id, trial_2h.trial_id, @@ -88,7 +87,8 @@ def test_schedule_trial(exp_storage: Storage.Experiment, # Get trials scheduled to run OR running within the next 3 hours: pending_ids = _trial_ids( - exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True)) + exp_storage.pending_trials(timestamp + timedelta_1hr * 3, running=True) + ) assert pending_ids == { trial_now1.trial_id, trial_now2.trial_id, @@ -110,7 +110,9 @@ def test_schedule_trial(exp_storage: Storage.Experiment, assert trial_status == [Status.SUCCEEDED, Status.FAILED, Status.SUCCEEDED] # Get only trials completed after trial_now2: - (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load(last_trial_id=trial_now2.trial_id) + (trial_ids, trial_configs, trial_scores, trial_status) = exp_storage.load( + last_trial_id=trial_now2.trial_id + ) assert trial_ids == [trial_1h.trial_id] assert len(trial_configs) == len(trial_scores) == 1 assert trial_status == [Status.SUCCEEDED] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index cffaaac4c6..aeb3d9fbee 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -29,28 +29,33 @@ def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, st """ timestamp1 = datetime.now(zone_info) timestamp2 = timestamp1 + timedelta(seconds=1) - return sorted([ - (timestamp1, "cpu_load", 10.1), - (timestamp1, "memory", 20), - (timestamp1, "setup", "prod"), - (timestamp2, "cpu_load", 30.1), - (timestamp2, "memory", 40), - (timestamp2, "setup", "prod"), - ]) + return sorted( + [ + (timestamp1, "cpu_load", 10.1), + (timestamp1, "memory", 20), + (timestamp1, "setup", "prod"), + (timestamp2, "cpu_load", 30.1), + (timestamp2, "memory", 40), + (timestamp2, "setup", "prod"), + ] + ) -def _telemetry_str(data: List[Tuple[datetime, str, Any]] - ) -> List[Tuple[datetime, str, Optional[str]]]: +def _telemetry_str( + data: 
List[Tuple[datetime, str, Any]] +) -> List[Tuple[datetime, str, Optional[str]]]: """Convert telemetry values to strings.""" # All retrieved timestamps should have been converted to UTC. return [(ts.astimezone(UTC), key, nullable(str, val)) for (ts, key, val) in data] @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry(storage: Storage, - exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo]) -> None: +def test_update_telemetry( + storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo], +) -> None: """Make sure update_telemetry() and load_telemetry() methods work.""" telemetry_data = zoned_telemetry_data(origin_zone_info) trial = exp_storage.new_trial(tunable_groups) @@ -67,9 +72,11 @@ def test_update_telemetry(storage: Storage, @pytest.mark.parametrize(("origin_zone_info"), ZONE_INFO) -def test_update_telemetry_twice(exp_storage: Storage.Experiment, - tunable_groups: TunableGroups, - origin_zone_info: Optional[tzinfo]) -> None: +def test_update_telemetry_twice( + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + origin_zone_info: Optional[tzinfo], +) -> None: """Make sure update_telemetry() call is idempotent.""" telemetry_data = zoned_telemetry_data(origin_zone_info) trial = exp_storage.new_trial(tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index ea13f63ea5..20ed746462 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -8,8 +8,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_trial_data_tunable_config_data(exp_data: ExperimentData, - tunable_groups: TunableGroups) -> None: +def test_trial_data_tunable_config_data( + exp_data: ExperimentData, tunable_groups: TunableGroups +) -> None: """Check expected return values for TunableConfigData.""" trial_id = 1 expected_config_id = 1 @@ -23,12 +24,12 @@ def test_trial_data_tunable_config_data(exp_data: ExperimentData, def test_trial_metadata(exp_data: ExperimentData) -> None: """Check expected return values for TunableConfigData metadata.""" - assert exp_data.objectives == {'score': 'min'} - for (trial_id, trial) in exp_data.trials.items(): + assert exp_data.objectives == {"score": "min"} + for trial_id, trial in exp_data.trials.items(): assert trial.metadata_dict == { - 'opt_target_0': 'score', - 'opt_direction_0': 'min', - 'trial_number': trial_id, + "opt_target_0": "score", + "opt_direction_0": "min", + "trial_number": trial_id, } @@ -40,12 +41,12 @@ def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData def test_mixed_numerics_exp_trial_data( - mixed_numerics_exp_data: ExperimentData, - mixed_numerics_tunable_groups: TunableGroups) -> None: + mixed_numerics_exp_data: ExperimentData, mixed_numerics_tunable_groups: TunableGroups +) -> None: """Tests that data type conversions are retained when loading experiment data with mixed numeric tunable types. 
""" trial = next(iter(mixed_numerics_exp_data.trials.values())) config = trial.tunable_config.config_dict - for (tunable, _group) in mixed_numerics_tunable_groups: + for tunable, _group in mixed_numerics_tunable_groups: assert isinstance(config[tunable.name], tunable.dtype) diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index 0646129e42..b8d83d5c32 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -14,10 +14,15 @@ def test_tunable_config_trial_group_data(exp_data: ExperimentData) -> None: trial_id = 1 trial = exp_data.trials[trial_id] tunable_config_trial_group = trial.tunable_config_trial_group - assert tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id + assert ( + tunable_config_trial_group.experiment_id == exp_data.experiment_id == trial.experiment_id + ) assert tunable_config_trial_group.tunable_config_id == trial.tunable_config_id assert tunable_config_trial_group.tunable_config == trial.tunable_config - assert tunable_config_trial_group == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group + assert ( + tunable_config_trial_group + == next(iter(tunable_config_trial_group.trials.values())).tunable_config_trial_group + ) def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) -> None: @@ -47,7 +52,9 @@ def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) # And so on ... -def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable_groups: TunableGroups) -> None: +def test_tunable_config_trial_group_results_df( + exp_data: ExperimentData, tunable_groups: TunableGroups +) -> None: """Tests the results_df property of the TunableConfigTrialGroup.""" tunable_config_id = 2 expected_group_id = 4 @@ -56,9 +63,14 @@ def test_tunable_config_trial_group_results_df(exp_data: ExperimentData, tunable # We shouldn't have the results for the other configs, just this one. 
expected_count = CONFIG_TRIAL_REPEAT_COUNT assert len(results_df) == expected_count - assert len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count + assert ( + len(results_df[(results_df["tunable_config_id"] == tunable_config_id)]) == expected_count + ) assert len(results_df[(results_df["tunable_config_id"] != tunable_config_id)]) == 0 - assert len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) == expected_count + assert ( + len(results_df[(results_df["tunable_config_trial_group_id"] == expected_group_id)]) + == expected_count + ) assert len(results_df[(results_df["tunable_config_trial_group_id"] != expected_group_id)]) == 0 assert len(results_df["trial_id"].unique()) == expected_count obj_target = next(iter(exp_data.objectives)) @@ -74,8 +86,14 @@ def test_tunable_config_trial_group_trials(exp_data: ExperimentData) -> None: tunable_config_trial_group = exp_data.tunable_config_trial_groups[tunable_config_id] trials = tunable_config_trial_group.trials assert len(trials) == CONFIG_TRIAL_REPEAT_COUNT - assert all(trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id - for trial in trials.values()) - assert all(trial.tunable_config_id == tunable_config_id - for trial in tunable_config_trial_group.trials.values()) - assert exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] + assert all( + trial.tunable_config_trial_group.tunable_config_trial_group_id == expected_group_id + for trial in trials.values() + ) + assert all( + trial.tunable_config_id == tunable_config_id + for trial in tunable_config_trial_group.trials.values() + ) + assert ( + exp_data.trials[expected_group_id] == tunable_config_trial_group.trials[expected_group_id] + ) diff --git a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py index cd7edcd005..87a4dcb0ba 100644 --- a/mlos_bench/mlos_bench/tests/test_with_alt_tz.py +++ b/mlos_bench/mlos_bench/tests/test_with_alt_tz.py @@ -22,7 +22,7 @@ ] -@pytest.mark.skipif(sys.platform == 'win32', reason="TZ environment variable is a UNIXism") +@pytest.mark.skipif(sys.platform == "win32", reason="TZ environment variable is a UNIXism") @pytest.mark.parametrize(("tz_name"), ZONE_NAMES) @pytest.mark.parametrize(("test_file"), TZ_TEST_FILES) def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: @@ -41,4 +41,6 @@ def test_trial_telemetry_alt_tz(tz_name: Optional[str], test_file: str) -> None: if cmd.returncode != 0: print(cmd.stdout.decode()) print(cmd.stderr.decode()) - raise AssertionError(f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'") + raise AssertionError( + f"Test(s) failed: # TZ='{tz_name}' '{sys.executable}' -m pytest -n0 '{test_file}'" + ) diff --git a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py index 64ab724be8..5f31eaef23 100644 --- a/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py +++ b/mlos_bench/mlos_bench/tests/tunable_groups_fixtures.py @@ -115,24 +115,26 @@ def mixed_numerics_tunable_groups() -> TunableGroups: tunable_groups : TunableGroups A new TunableGroups object for testing. 
""" - tunables = TunableGroups({ - "mix-numerics": { - "cost": 1, - "params": { - "int": { - "description": "An integer", - "type": "int", - "default": 0, - "range": [0, 100], + tunables = TunableGroups( + { + "mix-numerics": { + "cost": 1, + "params": { + "int": { + "description": "An integer", + "type": "int", + "default": 0, + "range": [0, 100], + }, + "float": { + "description": "A float", + "type": "float", + "default": 0, + "range": [0, 1], + }, }, - "float": { - "description": "A float", - "type": "float", - "default": 0, - "range": [0, 1], - }, - } - }, - }) + }, + } + ) tunables.reset() return tunables diff --git a/mlos_bench/mlos_bench/tests/tunables/conftest.py b/mlos_bench/mlos_bench/tests/tunables/conftest.py index f5b1629c9f..054e8c7d87 100644 --- a/mlos_bench/mlos_bench/tests/tunables/conftest.py +++ b/mlos_bench/mlos_bench/tests/tunables/conftest.py @@ -23,12 +23,15 @@ def tunable_categorical() -> Tunable: tunable : Tunable An instance of a categorical Tunable. """ - return Tunable("vmSize", { - "description": "Azure VM size", - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] - }) + return Tunable( + "vmSize", + { + "description": "Azure VM size", + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + }, + ) @pytest.fixture @@ -41,13 +44,16 @@ def tunable_int() -> Tunable: tunable : Tunable An instance of an integer Tunable. """ - return Tunable("kernel_sched_migration_cost_ns", { - "description": "Cost of migrating the thread to another core", - "type": "int", - "default": 40000, - "range": [0, 500000], - "special": [-1] # Special value outside of the range - }) + return Tunable( + "kernel_sched_migration_cost_ns", + { + "description": "Cost of migrating the thread to another core", + "type": "int", + "default": 40000, + "range": [0, 500000], + "special": [-1], # Special value outside of the range + }, + ) @pytest.fixture @@ -60,9 +66,12 @@ def tunable_float() -> Tunable: tunable : Tunable An instance of a float Tunable. 
""" - return Tunable("chaos_monkey_prob", { - "description": "Probability of spontaneous VM shutdown", - "type": "float", - "default": 0.01, - "range": [0, 1] - }) + return Tunable( + "chaos_monkey_prob", + { + "description": "Probability of spontaneous VM shutdown", + "type": "float", + "default": 0.01, + "range": [0, 1], + }, + ) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py index 28f92d4769..5ec31743bd 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunable_categoricals.py @@ -35,7 +35,7 @@ def test_tunable_categorical_types() -> None: "values": ["a", "b", "c"], "default": "a", }, - } + }, } } tunable_groups = TunableGroups(tunable_params) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index 768be65cb2..c792e82bcd 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -21,7 +21,8 @@ def test_tunable_int_size_props() -> None: "type": "int", "range": [1, 5], "default": 3, - }) + }, + ) assert tunable.span == 4 assert tunable.cardinality == 5 expected = [1, 2, 3, 4, 5] @@ -37,7 +38,8 @@ def test_tunable_float_size_props() -> None: "type": "float", "range": [1.5, 5], "default": 3, - }) + }, + ) assert tunable.span == 3.5 assert tunable.cardinality == np.inf assert tunable.quantized_values is None @@ -52,7 +54,8 @@ def test_tunable_categorical_size_props() -> None: "type": "categorical", "values": ["a", "b", "c"], "default": "a", - }) + }, + ) with pytest.raises(AssertionError): _ = tunable.span assert tunable.cardinality == 3 @@ -65,12 +68,8 @@ def test_tunable_quantized_int_size_props() -> None: """Test quantized tunable int size properties.""" tunable = Tunable( name="test", - config={ - "type": "int", - "range": [100, 1000], - "default": 100, - "quantization": 100 - }) + config={"type": "int", "range": [100, 1000], "default": 100, "quantization": 100}, + ) assert tunable.span == 900 assert tunable.cardinality == 10 expected = [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000] @@ -81,13 +80,8 @@ def test_tunable_quantized_int_size_props() -> None: def test_tunable_quantized_float_size_props() -> None: """Test quantized tunable float size properties.""" tunable = Tunable( - name="test", - config={ - "type": "float", - "range": [0, 1], - "default": 0, - "quantization": .1 - }) + name="test", config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1} + ) assert tunable.span == 1 assert tunable.cardinality == 11 expected = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py index 8d214c051b..ccf76d07c8 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_comparison_test.py @@ -22,7 +22,7 @@ def test_tunable_int_value_lt(tunable_int: Tunable) -> None: def test_tunable_int_name_lt(tunable_int: Tunable) -> None: """Tests that the __lt__ operator works as expected.""" tunable_int_2 = tunable_int.copy() - tunable_int_2._name = "aaa" # pylint: disable=protected-access + tunable_int_2._name = "aaa" # pylint: disable=protected-access assert tunable_int_2 < tunable_int @@ -30,7 +30,8 @@ def 
test_tunable_categorical_value_lt(tunable_categorical: Tunable) -> None: """Tests that the __lt__ operator works as expected.""" tunable_categorical_2 = tunable_categorical.copy() new_value = [ - x for x in tunable_categorical.categories + x + for x in tunable_categorical.categories if x != tunable_categorical.category and x is not None ][0] assert tunable_categorical.category is not None @@ -49,7 +50,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - } + }, ) tunable_dog = Tunable( name="same-name", @@ -57,7 +58,7 @@ def test_tunable_categorical_lt_null() -> None: "type": "categorical", "values": [None, "doggo"], "default": None, - } + }, ) assert tunable_dog < tunable_cat @@ -70,7 +71,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "categorical", "values": ["floof", "fuzz"], "default": "floof", - } + }, ) tunable_int = Tunable( name="same-name", @@ -78,7 +79,7 @@ def test_tunable_lt_same_name_different_type() -> None: "type": "int", "range": [1, 3], "default": 2, - } + }, ) assert tunable_cat < tunable_int @@ -87,7 +88,7 @@ def test_tunable_lt_different_object(tunable_int: Tunable) -> None: """Tests that the __lt__ operator works as expected.""" assert (tunable_int < "foo") is False with pytest.raises(TypeError): - assert "foo" < tunable_int # type: ignore[operator] + assert "foo" < tunable_int # type: ignore[operator] def test_tunable_group_ne_object(tunable_groups: TunableGroups) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py index 410404d66d..7403841f8d 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_definition_test.py @@ -14,7 +14,7 @@ def test_tunable_name() -> None: """Check that tunable name is valid.""" with pytest.raises(ValueError): # ! 
characters are currently disallowed in tunable names - Tunable(name='test!tunable', config={"type": "float", "range": [0, 1], "default": 0}) + Tunable(name="test!tunable", config={"type": "float", "range": [0, 1], "default": 0}) def test_categorical_required_params() -> None: @@ -28,7 +28,7 @@ def test_categorical_required_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_weights() -> None: @@ -42,7 +42,7 @@ def test_categorical_weights() -> None: } """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.weights == [25, 25, 50] @@ -58,7 +58,7 @@ def test_categorical_weights_wrong_count() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_weights_wrong_values() -> None: @@ -73,7 +73,7 @@ def test_categorical_weights_wrong_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_wrong_params() -> None: @@ -88,7 +88,7 @@ def test_categorical_wrong_params() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_disallow_special_values() -> None: @@ -103,51 +103,63 @@ def test_categorical_disallow_special_values() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_categorical_tunable_disallow_repeats() -> None: """Disallow duplicate values in categorical tunables.""" with pytest.raises(ValueError): - Tunable(name='test', config={ - "type": "categorical", - "values": ["foo", "bar", "foo"], - "default": "foo", - }) + Tunable( + name="test", + config={ + "type": "categorical", + "values": ["foo", "bar", "foo"], + "default": "foo", + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_disallow_null_default(tunable_type: TunableValueTypeName) -> None: """Disallow null values as default for numerical tunables.""" with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "default": None, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "default": None, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_disallow_out_of_range(tunable_type: TunableValueTypeName) -> None: """Disallow out of range values as default for numerical tunables.""" with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "default": 11, - }) + Tunable( + name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "default": 11, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_tunable_wrong_params(tunable_type: TunableValueTypeName) -> None: """Disallow values param for numerical tunables.""" with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config={ - "type": tunable_type, - "range": [0, 10], - "values": ["foo", "bar"], - "default": 0, - }) + Tunable( + 
name=f"test_{tunable_type}", + config={ + "type": tunable_type, + "range": [0, 10], + "values": ["foo", "bar"], + "default": 0, + }, + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -162,7 +174,7 @@ def test_numerical_tunable_required_params(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -177,7 +189,7 @@ def test_numerical_tunable_invalid_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(AssertionError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -192,7 +204,7 @@ def test_numerical_tunable_reversed_range(tunable_type: TunableValueTypeName) -> """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name=f'test_{tunable_type}', config=config) + Tunable(name=f"test_{tunable_type}", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -209,7 +221,7 @@ def test_numerical_weights(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.special == [0] assert tunable.weights == [0.1] assert tunable.range_weight == 0.9 @@ -227,7 +239,7 @@ def test_numerical_quantization(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.quantization == 10 assert not tunable.is_log @@ -244,7 +256,7 @@ def test_numerical_log(tunable_type: TunableValueTypeName) -> None: }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.is_log @@ -261,7 +273,7 @@ def test_numerical_weights_no_specials(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -280,7 +292,7 @@ def test_numerical_weights_non_normalized(tunable_type: TunableValueTypeName) -> }} """ config = json.loads(json_config) - tunable = Tunable(name='test', config=config) + tunable = Tunable(name="test", config=config) assert tunable.special == [-1, 0] assert tunable.weights == [0, 10] # Zero weights are ok assert tunable.range_weight == 90 @@ -301,7 +313,7 @@ def test_numerical_weights_wrong_count(tunable_type: TunableValueTypeName) -> No """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -318,7 +330,7 @@ def test_numerical_weights_no_range_weight(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -335,7 +347,7 @@ def test_numerical_range_weight_no_weights(tunable_type: TunableValueTypeName) - """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', 
config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -351,7 +363,7 @@ def test_numerical_range_weight_no_specials(tunable_type: TunableValueTypeName) """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -369,7 +381,7 @@ def test_numerical_weights_wrong_values(tunable_type: TunableValueTypeName) -> N """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) @pytest.mark.parametrize("tunable_type", ["int", "float"]) @@ -385,7 +397,7 @@ def test_numerical_quantization_wrong(tunable_type: TunableValueTypeName) -> Non """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test', config=config) + Tunable(name="test", config=config) def test_bad_type() -> None: @@ -399,4 +411,4 @@ def test_bad_type() -> None: """ config = json.loads(json_config) with pytest.raises(ValueError): - Tunable(name='test_bad_type', config=config) + Tunable(name="test_bad_type", config=config) diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py index 68c560b1cd..54f08e1709 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_distributions_test.py @@ -13,27 +13,29 @@ def test_categorical_distribution() -> None: """Try to instantiate a categorical tunable with distribution specified.""" with pytest.raises(ValueError): - Tunable(name='test', config={ - "type": "categorical", - "values": ["foo", "bar", "baz"], - "distribution": { - "type": "uniform" + Tunable( + name="test", + config={ + "type": "categorical", + "values": ["foo", "bar", "baz"], + "distribution": {"type": "uniform"}, + "default": "foo", }, - "default": "foo" - }) + ) @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> None: """Create a numeric Tunable with explicit uniform distribution.""" - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "uniform" + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "uniform"}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.is_numerical assert tunable.distribution == "uniform" assert not tunable.distribution_params @@ -42,18 +44,15 @@ def test_numerical_distribution_uniform(tunable_type: TunableValueTypeName) -> N @pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> None: """Create a numeric Tunable with explicit Gaussian distribution specified.""" - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "normal", - "params": { - "mu": 0, - "sigma": 1.0 - } + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "normal", "params": {"mu": 0, "sigma": 1.0}}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.distribution == "normal" assert tunable.distribution_params == {"mu": 0, "sigma": 1.0} @@ -61,18 +60,15 @@ def test_numerical_distribution_normal(tunable_type: TunableValueTypeName) -> No 
@pytest.mark.parametrize("tunable_type", ["int", "float"]) def test_numerical_distribution_beta(tunable_type: TunableValueTypeName) -> None: """Create a numeric Tunable with explicit Beta distribution specified.""" - tunable = Tunable(name="test", config={ - "type": tunable_type, - "range": [0, 10], - "distribution": { - "type": "beta", - "params": { - "alpha": 2, - "beta": 5 - } + tunable = Tunable( + name="test", + config={ + "type": tunable_type, + "range": [0, 10], + "distribution": {"type": "beta", "params": {"alpha": 2, "beta": 5}}, + "default": 0, }, - "default": 0 - }) + ) assert tunable.distribution == "beta" assert tunable.distribution_params == {"alpha": 2, "beta": 5} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py index eee8a47e3c..6e4d9c3658 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py @@ -8,7 +8,9 @@ from mlos_bench.tunables.tunable_groups import TunableGroups -def test_tunable_group_indexing(tunable_groups: TunableGroups, tunable_categorical: Tunable) -> None: +def test_tunable_group_indexing( + tunable_groups: TunableGroups, tunable_categorical: Tunable +) -> None: """Check that various types of indexing work for the tunable group.""" # Check that the "in" operator works. assert tunable_categorical in tunable_groups diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py index 274b4d6a43..c44fbfc866 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_subgroup_test.py @@ -12,4 +12,4 @@ def test_tunable_group_subgroup(tunable_groups: TunableGroups) -> None: parameters. 
""" tunables = tunable_groups.subgroup(["provision"]) - assert tunables.get_param_values() == {'vmSize': 'Standard_B4ms'} + assert tunables.get_param_values() == {"vmSize": "Standard_B4ms"} diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index dfd4b4c610..97b8ea8c41 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -35,35 +35,37 @@ @pytest.mark.parametrize("param_type", ["int", "float"]) -@pytest.mark.parametrize("distr_name,distr_params", [ - ("normal", {"mu": 0.0, "sigma": 1.0}), - ("beta", {"alpha": 2, "beta": 5}), - ("uniform", {}), -]) -def test_convert_numerical_distributions(param_type: str, - distr_name: DistributionName, - distr_params: dict) -> None: +@pytest.mark.parametrize( + "distr_name,distr_params", + [ + ("normal", {"mu": 0.0, "sigma": 1.0}), + ("beta", {"alpha": 2, "beta": 5}), + ("uniform", {}), + ], +) +def test_convert_numerical_distributions( + param_type: str, distr_name: DistributionName, distr_params: dict +) -> None: """Convert a numerical Tunable with explicit distribution to ConfigSpace.""" tunable_name = "x" - tunable_groups = TunableGroups({ - "tunable_group": { - "cost": 1, - "params": { - tunable_name: { - "type": param_type, - "range": [0, 100], - "special": [-1, 0], - "special_weights": [0.1, 0.2], - "range_weight": 0.7, - "distribution": { - "type": distr_name, - "params": distr_params - }, - "default": 0 - } + tunable_groups = TunableGroups( + { + "tunable_group": { + "cost": 1, + "params": { + tunable_name: { + "type": param_type, + "range": [0, 100], + "special": [-1, 0], + "special_weights": [0.1, 0.2], + "range_weight": 0.7, + "distribution": {"type": distr_name, "params": distr_params}, + "default": 0, + } + }, } } - }) + ) (tunable, _group) = tunable_groups.get_tunable(tunable_name) assert tunable.distribution == distr_name @@ -79,5 +81,5 @@ def test_convert_numerical_distributions(param_type: str, cs_param = space[tunable_name] assert isinstance(cs_param, _CS_HYPERPARAMETER[param_type, distr_name]) - for (key, val) in distr_params.items(): + for key, val in distr_params.items(): assert getattr(cs_param, key) == val diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 7936277ec7..5350e5e4eb 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -36,17 +36,23 @@ def configuration_space() -> ConfigurationSpace: configuration_space : ConfigurationSpace A new ConfigurationSpace object for testing. 
""" - (kernel_sched_migration_cost_ns_special, - kernel_sched_migration_cost_ns_type) = special_param_names("kernel_sched_migration_cost_ns") - - spaces = ConfigurationSpace(space={ - "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], - "idle": ["halt", "mwait", "noidle"], - "kernel_sched_migration_cost_ns": (0, 500000), - kernel_sched_migration_cost_ns_special: [-1, 0], - kernel_sched_migration_cost_ns_type: [TunableValueKind.SPECIAL, TunableValueKind.RANGE], - "kernel_sched_latency_ns": (0, 1000000000), - }) + (kernel_sched_migration_cost_ns_special, kernel_sched_migration_cost_ns_type) = ( + special_param_names("kernel_sched_migration_cost_ns") + ) + + spaces = ConfigurationSpace( + space={ + "vmSize": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + "idle": ["halt", "mwait", "noidle"], + "kernel_sched_migration_cost_ns": (0, 500000), + kernel_sched_migration_cost_ns_special: [-1, 0], + kernel_sched_migration_cost_ns_type: [ + TunableValueKind.SPECIAL, + TunableValueKind.RANGE, + ], + "kernel_sched_latency_ns": (0, 1000000000), + } + ) # NOTE: FLAML requires distribution to be uniform spaces["vmSize"].default_value = "Standard_B4ms" @@ -58,18 +64,25 @@ def configuration_space() -> ConfigurationSpace: spaces[kernel_sched_migration_cost_ns_type].probabilities = (0.5, 0.5) spaces["kernel_sched_latency_ns"].default_value = 2000000 - spaces.add_condition(EqualsCondition( - spaces[kernel_sched_migration_cost_ns_special], - spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.SPECIAL)) - spaces.add_condition(EqualsCondition( - spaces["kernel_sched_migration_cost_ns"], - spaces[kernel_sched_migration_cost_ns_type], TunableValueKind.RANGE)) + spaces.add_condition( + EqualsCondition( + spaces[kernel_sched_migration_cost_ns_special], + spaces[kernel_sched_migration_cost_ns_type], + TunableValueKind.SPECIAL, + ) + ) + spaces.add_condition( + EqualsCondition( + spaces["kernel_sched_migration_cost_ns"], + spaces[kernel_sched_migration_cost_ns_type], + TunableValueKind.RANGE, + ) + ) return spaces -def _cmp_tunable_hyperparameter_categorical( - tunable: Tunable, space: ConfigurationSpace) -> None: +def _cmp_tunable_hyperparameter_categorical(tunable: Tunable, space: ConfigurationSpace) -> None: """Check if categorical Tunable and ConfigSpace Hyperparameter actually match.""" param = space[tunable.name] assert isinstance(param, CategoricalHyperparameter) @@ -77,8 +90,7 @@ def _cmp_tunable_hyperparameter_categorical( assert param.default_value == tunable.value -def _cmp_tunable_hyperparameter_numerical( - tunable: Tunable, space: ConfigurationSpace) -> None: +def _cmp_tunable_hyperparameter_numerical(tunable: Tunable, space: ConfigurationSpace) -> None: """Check if integer Tunable and ConfigSpace Hyperparameter actually match.""" param = space[tunable.name] assert isinstance(param, (UniformIntegerHyperparameter, UniformFloatHyperparameter)) @@ -119,12 +131,13 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> Non Make sure that the corresponding Tunable and Hyperparameter objects match. 
""" space = tunable_groups_to_configspace(tunable_groups) - for (tunable, _group) in tunable_groups: + for tunable, _group in tunable_groups: _CMP_FUNC[tunable.type](tunable, space) def test_tunable_groups_to_configspace( - tunable_groups: TunableGroups, configuration_space: ConfigurationSpace) -> None: + tunable_groups: TunableGroups, configuration_space: ConfigurationSpace +) -> None: """Check the conversion of the entire TunableGroups collection to a single ConfigurationSpace object. """ diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py index 5893e9440a..05f29a9064 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_assign_test.py @@ -18,12 +18,14 @@ def test_tunables_assign_unknown_param(tunable_groups: TunableGroups) -> None: TunableGroups object. """ with pytest.raises(KeyError): - tunable_groups.assign({ - "vmSize": "Standard_B2ms", - "idle": "mwait", - "UnknownParam_1": 1, - "UnknownParam_2": "invalid-value" - }) + tunable_groups.assign( + { + "vmSize": "Standard_B2ms", + "idle": "mwait", + "UnknownParam_1": 1, + "UnknownParam_2": "invalid-value", + } + ) def test_tunables_assign_categorical(tunable_categorical: Tunable) -> None: @@ -85,14 +87,14 @@ def test_tunable_assign_float_to_numerical_value(tunable_float: Tunable) -> None def test_tunable_assign_str_to_int(tunable_int: Tunable) -> None: """Check str to int coercion.""" tunable_int.value = "10" - assert tunable_int.value == 10 # type: ignore[comparison-overlap] + assert tunable_int.value == 10 # type: ignore[comparison-overlap] assert not tunable_int.is_special def test_tunable_assign_str_to_float(tunable_float: Tunable) -> None: """Check str to float coercion.""" tunable_float.value = "0.5" - assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] + assert tunable_float.value == 0.5 # type: ignore[comparison-overlap] assert not tunable_float.is_special @@ -120,12 +122,12 @@ def test_tunable_assign_null_to_categorical() -> None: } """ config = json.loads(json_config) - categorical_tunable = Tunable(name='categorical_test', config=config) + categorical_tunable = Tunable(name="categorical_test", config=config) assert categorical_tunable assert categorical_tunable.category == "foo" categorical_tunable.value = None assert categorical_tunable.value is None - assert categorical_tunable.value != 'None' + assert categorical_tunable.value != "None" assert categorical_tunable.category is None @@ -134,7 +136,7 @@ def test_tunable_assign_null_to_int(tunable_int: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_int.value = None with pytest.raises((TypeError, AssertionError)): - tunable_int.numerical_value = None # type: ignore[assignment] + tunable_int.numerical_value = None # type: ignore[assignment] def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: @@ -142,7 +144,7 @@ def test_tunable_assign_null_to_float(tunable_float: Tunable) -> None: with pytest.raises((TypeError, AssertionError)): tunable_float.value = None with pytest.raises((TypeError, AssertionError)): - tunable_float.numerical_value = None # type: ignore[assignment] + tunable_float.numerical_value = None # type: ignore[assignment] def test_tunable_assign_special(tunable_int: Tunable) -> None: diff --git a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py index 1f909a63e7..61514c605b 100644 --- 
a/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunables_str_test.py @@ -15,42 +15,44 @@ def test_tunable_groups_str(tunable_groups: TunableGroups) -> None: covariant group. """ # Same as `tunable_groups` (defined in the `conftest.py` file), but in different order: - tunables_other = TunableGroups({ - "kernel": { - "cost": 1, - "params": { - "kernel_sched_latency_ns": { - "type": "int", - "default": 2000000, - "range": [0, 1000000000] + tunables_other = TunableGroups( + { + "kernel": { + "cost": 1, + "params": { + "kernel_sched_latency_ns": { + "type": "int", + "default": 2000000, + "range": [0, 1000000000], + }, + "kernel_sched_migration_cost_ns": { + "type": "int", + "default": -1, + "range": [0, 500000], + "special": [-1], + }, }, - "kernel_sched_migration_cost_ns": { - "type": "int", - "default": -1, - "range": [0, 500000], - "special": [-1] - } - } - }, - "boot": { - "cost": 300, - "params": { - "idle": { - "type": "categorical", - "default": "halt", - "values": ["halt", "mwait", "noidle"] - } - } - }, - "provision": { - "cost": 1000, - "params": { - "vmSize": { - "type": "categorical", - "default": "Standard_B4ms", - "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"] - } - } - }, - }) + }, + "boot": { + "cost": 300, + "params": { + "idle": { + "type": "categorical", + "default": "halt", + "values": ["halt", "mwait", "noidle"], + } + }, + }, + "provision": { + "cost": 1000, + "params": { + "vmSize": { + "type": "categorical", + "default": "Standard_B4ms", + "values": ["Standard_B2s", "Standard_B2ms", "Standard_B4ms"], + } + }, + }, + } + ) assert str(tunable_groups) == str(tunables_other) diff --git a/mlos_bench/mlos_bench/tunables/__init__.py b/mlos_bench/mlos_bench/tunables/__init__.py index 58106a606e..c5f49e9202 100644 --- a/mlos_bench/mlos_bench/tunables/__init__.py +++ b/mlos_bench/mlos_bench/tunables/__init__.py @@ -8,7 +8,7 @@ from mlos_bench.tunables.tunable_groups import TunableGroups __all__ = [ - 'Tunable', - 'TunableValue', - 'TunableGroups', + "Tunable", + "TunableValue", + "TunableGroups", ] diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index 3eba2cb9db..1468ce5545 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -92,10 +92,12 @@ def __eq__(self, other: object) -> bool: return False # TODO: May need to provide logic to relax the equality check on the # tunables (e.g. "compatible" vs. "equal"). 
- return (self._name == other._name and - self._cost == other._cost and - self._is_updated == other._is_updated and - self._tunables == other._tunables) + return ( + self._name == other._name + and self._cost == other._cost + and self._is_updated == other._is_updated + and self._tunables == other._tunables + ) def equals_defaults(self, other: "CovariantTunableGroup") -> bool: """ @@ -233,7 +235,11 @@ def __contains__(self, tunable: Union[str, Tunable]) -> bool: def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: return self.get_tunable(tunable).value - def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: - value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + def __setitem__( + self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + ) -> TunableValue: + value: TunableValue = ( + tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + ) self._is_updated |= self.get_tunable(tunable).update(value) return value diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 7bda45e49e..9be5ea9f37 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -95,7 +95,7 @@ def __init__(self, name: str, config: TunableDict): config : dict Python dict that represents a Tunable (e.g., deserialized from JSON) """ - if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema + if not isinstance(name, str) or "!" in name: # TODO: Use a regex here and in JSON schema raise ValueError(f"Invalid name of the tunable: {name}") self._name = name self._type: TunableValueTypeName = config["type"] # required @@ -190,10 +190,16 @@ def _sanity_check_numerical(self) -> None: raise ValueError(f"Number of quantization points is <= 1: {self}") if self.dtype == float: if not isinstance(self._quantization, (float, int)): - raise ValueError(f"Quantization of a float param should be a float or int: {self}") + raise ValueError( + f"Quantization of a float param should be a float or int: {self}" + ) if self._quantization <= 0: raise ValueError(f"Number of quantization points is <= 0: {self}") - if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}: + if self._distribution is not None and self._distribution not in { + "uniform", + "normal", + "beta", + }: raise ValueError(f"Invalid distribution: {self}") if self._distribution_params and self._distribution is None: raise ValueError(f"Must specify the distribution: {self}") @@ -218,7 +224,9 @@ def __repr__(self) -> str: """ # TODO? Add weights, specials, quantization, distribution? 
if self.is_categorical: - return f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}" + return ( + f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}" + ) return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}" def __eq__(self, other: object) -> bool: @@ -239,12 +247,12 @@ def __eq__(self, other: object) -> bool: if not isinstance(other, Tunable): return False return bool( - self._name == other._name and - self._type == other._type and - self._current_value == other._current_value + self._name == other._name + and self._type == other._type + and self._current_value == other._current_value ) - def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements + def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements """ Compare the two Tunable objects. We mostly need this to create a canonical list of tunable objects when hashing a TunableGroup. @@ -318,18 +326,21 @@ def value(self, value: TunableValue) -> TunableValue: assert value is not None coerced_value = self.dtype(value) except Exception: - _LOG.error("Impossible conversion: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Impossible conversion: %s %s <- %s %s", self._type, self._name, type(value), value + ) raise if self._type == "int" and isinstance(value, float) and value != coerced_value: - _LOG.error("Loss of precision: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Loss of precision: %s %s <- %s %s", self._type, self._name, type(value), value + ) raise ValueError(f"Loss of precision: {self._name}={value}") if not self.is_valid(coerced_value): - _LOG.error("Invalid assignment: %s %s <- %s %s", - self._type, self._name, type(value), value) + _LOG.error( + "Invalid assignment: %s %s <- %s %s", self._type, self._name, type(value), value + ) raise ValueError(f"Invalid value for the tunable: {self._name}={value}") self._current_value = coerced_value @@ -387,10 +398,10 @@ def in_range(self, value: Union[int, float, str, None]) -> bool: categorical or None. """ return ( - isinstance(value, (float, int)) and - self.is_numerical and - self._range is not None and - bool(self._range[0] <= value <= self._range[1]) + isinstance(value, (float, int)) + and self.is_numerical + and self._range is not None + and bool(self._range[0] <= value <= self._range[1]) ) @property @@ -600,10 +611,12 @@ def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]: # Be sure to return python types instead of numpy types. 
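[Editor's illustration, not part of the patch] The reflowed value setter above coerces assigned values to the tunable's dtype and rejects float-to-int assignments that would silently lose precision. A stand-alone sketch of that rule, separate from the Tunable class, is below.

# Illustration only: the coercion rule the value setter enforces for int tunables.
def coerce_int_value(value):
    coerced = int(value)
    if isinstance(value, float) and value != coerced:
        raise ValueError(f"Loss of precision: {value}")
    return coerced

assert coerce_int_value("10") == 10   # str -> int coercion is allowed
assert coerce_int_value(10.0) == 10   # an exact float is fine
# coerce_int_value(10.5) would raise ValueError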
cardinality = self.cardinality assert isinstance(cardinality, int) - return (float(x) for x in np.linspace(start=num_range[0], - stop=num_range[1], - num=cardinality, - endpoint=True)) + return ( + float(x) + for x in np.linspace( + start=num_range[0], stop=num_range[1], num=cardinality, endpoint=True + ) + ) assert self.type == "int", f"Unhandled tunable type: {self}" return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1)) diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index bc56f20f45..684d15f120 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -26,9 +26,11 @@ def __init__(self, config: Optional[dict] = None): if config is None: config = {} ConfigSchema.TUNABLE_PARAMS.validate(config) - self._index: Dict[str, CovariantTunableGroup] = {} # Index (Tunable id -> CovariantTunableGroup) + self._index: Dict[str, CovariantTunableGroup] = ( + {} + ) # Index (Tunable id -> CovariantTunableGroup) self._tunable_groups: Dict[str, CovariantTunableGroup] = {} - for (name, group_config) in config.items(): + for name, group_config in config.items(): self._add_group(CovariantTunableGroup(name, group_config)) def __bool__(self) -> bool: @@ -77,11 +79,15 @@ def _add_group(self, group: CovariantTunableGroup) -> None: ---------- group : CovariantTunableGroup """ - assert group.name not in self._tunable_groups, f"Duplicate covariant tunable group name {group.name} in {self}" + assert ( + group.name not in self._tunable_groups + ), f"Duplicate covariant tunable group name {group.name} in {self}" self._tunable_groups[group.name] = group for tunable in group.get_tunables(): if tunable.name in self._index: - raise ValueError(f"Duplicate Tunable {tunable.name} from group {group.name} in {self}") + raise ValueError( + f"Duplicate Tunable {tunable.name} from group {group.name} in {self}" + ) self._index[tunable.name] = group def merge(self, tunables: "TunableGroups") -> "TunableGroups": @@ -115,8 +121,10 @@ def merge(self, tunables: "TunableGroups") -> "TunableGroups": # Check that there's no overlap in the tunables. # But allow for differing current values. if not self._tunable_groups[group.name].equals_defaults(group): - raise ValueError(f"Overlapping covariant tunable group name {group.name} " + - "in {self._tunable_groups[group.name]} and {tunables}") + raise ValueError( + f"Overlapping covariant tunable group name {group.name} " + + "in {self._tunable_groups[group.name]} and {tunables}" + ) return self def __repr__(self) -> str: @@ -128,10 +136,15 @@ def __repr__(self) -> str: string : str A human-readable version of the TunableGroups. 
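[Editor's illustration, not part of the patch] The quantized_values change above has two paths: floats use np.linspace over the cardinality, ints use a stepped range. The sketch below reproduces both with plain numpy and range, using the same example ranges as the size-props tests earlier in the patch; it is not the Tunable implementation.

# Minimal sketch of the two quantization paths, mirroring the size-props tests.
import numpy as np

# float tunable: range [0, 1], quantization 0.1 -> 11 evenly spaced points
float_points = [float(x) for x in np.linspace(start=0.0, stop=1.0, num=11, endpoint=True)]
assert float_points[0] == 0.0 and float_points[-1] == 1.0 and len(float_points) == 11

# int tunable: range [100, 1000], quantization 100 -> 10 values
int_points = list(range(100, 1000 + 1, 100))
assert int_points == [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]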
""" - return "{ " + ", ".join( - f"{group.name}::{tunable}" - for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) - for tunable in sorted(group._tunables.values())) + " }" + return ( + "{ " + + ", ".join( + f"{group.name}::{tunable}" + for group in sorted(self._tunable_groups.values(), key=lambda g: (-g.cost, g.name)) + for tunable in sorted(group._tunables.values()) + ) + + " }" + ) def __contains__(self, tunable: Union[str, Tunable]) -> bool: """Checks if the given name/tunable is in this tunable group.""" @@ -143,11 +156,15 @@ def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: name: str = tunable.name if isinstance(tunable, Tunable) else tunable return self._index[name][name] - def __setitem__(self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable]) -> TunableValue: + def __setitem__( + self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + ) -> TunableValue: """Update the current value of a single tunable parameter.""" # Use double index to make sure we set the is_updated flag of the group name: str = tunable.name if isinstance(tunable, Tunable) else tunable - value: TunableValue = tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + value: TunableValue = ( + tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value + ) self._index[name][name] = value return self._index[name][name] @@ -222,8 +239,11 @@ def subgroup(self, group_names: Iterable[str]) -> "TunableGroups": tunables._add_group(self._tunable_groups[name]) return tunables - def get_param_values(self, group_names: Optional[Iterable[str]] = None, - into_params: Optional[Dict[str, TunableValue]] = None) -> Dict[str, TunableValue]: + def get_param_values( + self, + group_names: Optional[Iterable[str]] = None, + into_params: Optional[Dict[str, TunableValue]] = None, + ) -> Dict[str, TunableValue]: """ Get the current values of the tunables that belong to the specified covariance groups. @@ -263,8 +283,10 @@ def is_updated(self, group_names: Optional[Iterable[str]] = None) -> bool: is_updated : bool True if any of the specified tunable groups has been updated, False otherwise. """ - return any(self._tunable_groups[name].is_updated() - for name in (group_names or self.get_covariant_group_names())) + return any( + self._tunable_groups[name].is_updated() + for name in (group_names or self.get_covariant_group_names()) + ) def is_defaults(self) -> bool: """ @@ -291,7 +313,7 @@ def restore_defaults(self, group_names: Optional[Iterable[str]] = None) -> "Tuna self : TunableGroups Self-reference for chaining. """ - for name in (group_names or self.get_covariant_group_names()): + for name in group_names or self.get_covariant_group_names(): self._tunable_groups[name].restore_defaults() return self @@ -309,7 +331,7 @@ def reset(self, group_names: Optional[Iterable[str]] = None) -> "TunableGroups": self : TunableGroups Self-reference for chaining. 
""" - for name in (group_names or self.get_covariant_group_names()): + for name in group_names or self.get_covariant_group_names(): self._tunable_groups[name].reset_is_updated() return self diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index d516eb5337..37170b06c0 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -69,8 +69,9 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> return dest -def merge_parameters(*, dest: dict, source: Optional[dict] = None, - required_keys: Optional[Iterable[str]] = None) -> dict: +def merge_parameters( + *, dest: dict, source: Optional[dict] = None, required_keys: Optional[Iterable[str]] = None +) -> dict: """ Merge the source config dict into the destination config. Pick from the source configs *ONLY* the keys that are already present in the destination config. @@ -129,8 +130,9 @@ def path_join(*args: str, abs_path: bool = False) -> str: return os.path.normpath(path).replace("\\", "/") -def prepare_class_load(config: dict, - global_config: Optional[Dict[str, Any]] = None) -> Tuple[str, Dict[str, Any]]: +def prepare_class_load( + config: dict, global_config: Optional[Dict[str, Any]] = None +) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. @@ -152,8 +154,9 @@ def prepare_class_load(config: dict, merge_parameters(dest=class_config, source=global_config) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Instantiating: %s with config:\n%s", - class_name, json.dumps(class_config, indent=2)) + _LOG.debug( + "Instantiating: %s with config:\n%s", class_name, json.dumps(class_config, indent=2) + ) return (class_name, class_config) @@ -184,8 +187,9 @@ def get_class_from_name(class_name: str) -> type: # FIXME: Technically, this should return a type "class_name" derived from "base_class". -def instantiate_from_config(base_class: Type[BaseTypeVar], class_name: str, - *args: Any, **kwargs: Any) -> BaseTypeVar: +def instantiate_from_config( + base_class: Type[BaseTypeVar], class_name: str, *args: Any, **kwargs: Any +) -> BaseTypeVar: """ Factory method for a new class instantiated from config. 
@@ -235,7 +239,8 @@ def check_required_params(config: Mapping[str, Any], required_params: Iterable[s if missing_params: raise ValueError( "The following parameters must be provided in the configuration" - + f" or as command line arguments: {missing_params}") + + f" or as command line arguments: {missing_params}" + ) def get_git_info(path: str = __file__) -> Tuple[str, str, str]: @@ -254,11 +259,14 @@ def get_git_info(path: str = __file__) -> Tuple[str, str, str]: """ dirname = os.path.dirname(path) git_repo = subprocess.check_output( - ["git", "-C", dirname, "remote", "get-url", "origin"], text=True).strip() + ["git", "-C", dirname, "remote", "get-url", "origin"], text=True + ).strip() git_commit = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "HEAD"], text=True).strip() + ["git", "-C", dirname, "rev-parse", "HEAD"], text=True + ).strip() git_root = subprocess.check_output( - ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True).strip() + ["git", "-C", dirname, "rev-parse", "--show-toplevel"], text=True + ).strip() _LOG.debug("Current git branch: %s %s", git_repo, git_commit) rel_path = os.path.relpath(os.path.abspath(path), os.path.abspath(git_root)) return (git_repo, git_commit, rel_path.replace("\\", "/")) @@ -352,7 +360,9 @@ def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> raise ValueError(f"Invalid origin: {origin}") -def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal["utc", "local"]) -> Optional[datetime]: +def utcify_nullable_timestamp( + timestamp: Optional[datetime], *, origin: Literal["utc", "local"] +) -> Optional[datetime]: """A nullable version of utcify_timestamp.""" return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None @@ -362,7 +372,9 @@ def utcify_nullable_timestamp(timestamp: Optional[datetime], *, origin: Literal[ _MIN_TS = datetime(2024, 1, 1, 0, 0, 0, tzinfo=pytz.UTC) -def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "local"]) -> pandas.Series: +def datetime_parser( + datetime_col: pandas.Series, *, origin: Literal["utc", "local"] +) -> pandas.Series: """ Attempt to convert a pandas column to a datetime format. @@ -396,7 +408,7 @@ def datetime_parser(datetime_col: pandas.Series, *, origin: Literal["utc", "loca new_datetime_col = new_datetime_col.dt.tz_localize(tzinfo) assert new_datetime_col.dt.tz is not None # And convert it to UTC. - new_datetime_col = new_datetime_col.dt.tz_convert('UTC') + new_datetime_col = new_datetime_col.dt.tz_convert("UTC") if new_datetime_col.isna().any(): raise ValueError(f"Invalid date format in the data: {datetime_col}") if new_datetime_col.le(_MIN_TS).any(): diff --git a/mlos_bench/mlos_bench/version.py b/mlos_bench/mlos_bench/version.py index 520192b647..ab6ab85d2d 100644 --- a/mlos_bench/mlos_bench/version.py +++ b/mlos_bench/mlos_bench/version.py @@ -5,7 +5,7 @@ """Version number for the mlos_bench package.""" # NOTE: This should be managed by bumpversion. 
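[Editor's illustration, not part of the patch] The datetime_parser hunk above localizes naive timestamps to the declared origin timezone and then converts them to UTC. The pandas sketch below shows just that localize-then-convert step with made-up sample values; the real helper also checks for parse failures and a minimum timestamp.

# Small pandas sketch of the localize-then-convert step (sample values made up).
import pandas

naive = pandas.Series(pandas.to_datetime(["2024-06-01 12:00:00", "2024-06-02 01:30:00"]))
assert naive.dt.tz is None                            # timestamps start out naive
localized = naive.dt.tz_localize("America/Chicago")   # assume a "local" origin timezone
as_utc = localized.dt.tz_convert("UTC")
assert as_utc.dt.tz is not None                       # now timezone-aware, in UTC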
-VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index 9e00657dfa..f86b7a9663 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -19,15 +19,16 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -45,62 +46,68 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. - lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } -extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass +extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass # Additional tools for extra functionality. - 'azure': ['azure-storage-file-share', 'azure-identity', 'azure-keyvault'], - 'ssh': ['asyncssh'], - 'storage-sql-duckdb': ['sqlalchemy', 'duckdb_engine'], - 'storage-sql-mysql': ['sqlalchemy', 'mysql-connector-python'], - 'storage-sql-postgres': ['sqlalchemy', 'psycopg2'], - 'storage-sql-sqlite': ['sqlalchemy'], # sqlite3 comes with python, so we don't need to install it. + "azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"], + "ssh": ["asyncssh"], + "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"], + "storage-sql-mysql": ["sqlalchemy", "mysql-connector-python"], + "storage-sql-postgres": ["sqlalchemy", "psycopg2"], + "storage-sql-sqlite": [ + "sqlalchemy" + ], # sqlite3 comes with python, so we don't need to install it. # Transitive extra_requires from mlos-core. - 'flaml': ['flaml[blendsearch]'], - 'smac': ['smac'], + "flaml": ["flaml[blendsearch]"], + "smac": ["smac"], } # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. 
-extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', - 'pytest-lazy-fixtures', - 'pytest-docker', - 'fasteners', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", + "pytest-lazy-fixtures", + "pytest-docker", + "fasteners", ] setup( version=VERSION, install_requires=[ - 'mlos-core==' + VERSION, - 'requests', - 'json5', - 'jsonschema>=4.18.0', 'referencing>=0.29.1', + "mlos-core==" + VERSION, + "requests", + "json5", + "jsonschema>=4.18.0", + "referencing>=0.29.1", 'importlib_resources;python_version<"3.10"', - ] + extra_requires['storage-sql-sqlite'], # NOTE: For now sqlite is a fallback storage backend, so we always install it. + ] + + extra_requires[ + "storage-sql-sqlite" + ], # NOTE: For now sqlite is a fallback storage backend, so we always install it. extras_require=extra_requires, - **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_bench'), + **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_bench"), ) diff --git a/mlos_core/mlos_core/optimizers/__init__.py b/mlos_core/mlos_core/optimizers/__init__.py index c72600be02..396bd5e212 100644 --- a/mlos_core/mlos_core/optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/__init__.py @@ -16,12 +16,12 @@ from mlos_core.spaces.adapters import SpaceAdapterFactory, SpaceAdapterType __all__ = [ - 'SpaceAdapterType', - 'OptimizerFactory', - 'BaseOptimizer', - 'RandomOptimizer', - 'FlamlOptimizer', - 'SmacOptimizer', + "SpaceAdapterType", + "OptimizerFactory", + "BaseOptimizer", + "RandomOptimizer", + "FlamlOptimizer", + "SmacOptimizer", ] @@ -43,7 +43,7 @@ class OptimizerType(Enum): # ConcreteOptimizer = TypeVar('ConcreteOptimizer', *[member.value for member in OptimizerType]) # To address this, we add a test for complete coverage of the enum. ConcreteOptimizer = TypeVar( - 'ConcreteOptimizer', + "ConcreteOptimizer", RandomOptimizer, FlamlOptimizer, SmacOptimizer, @@ -58,13 +58,15 @@ class OptimizerFactory: # pylint: disable=too-few-public-methods @staticmethod - def create(*, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, - optimizer_kwargs: Optional[dict] = None, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None) -> ConcreteOptimizer: # type: ignore[type-var] + def create( + *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + optimizer_type: OptimizerType = DEFAULT_OPTIMIZER_TYPE, + optimizer_kwargs: Optional[dict] = None, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None, + ) -> ConcreteOptimizer: # type: ignore[type-var] """ Create a new optimizer instance, given the parameter space, optimizer type, and potential optimizer options. 
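[Editor's illustration, not part of the patch] The 'full' extra above is built by flattening every per-feature requirement list and deduplicating it. The stand-alone snippet below shows that construction with a trimmed-down stand-in for the real extra_requires mapping (sorted here only to make the result deterministic).

# Stand-alone illustration of building the "full" extra; the dict is a stand-in.
from itertools import chain

extras = {
    "azure": ["azure-identity", "azure-keyvault"],
    "ssh": ["asyncssh"],
    "storage-sql-sqlite": ["sqlalchemy"],
    "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"],
}
extras["full"] = sorted(set(chain(*extras.values())))
assert "sqlalchemy" in extras["full"] and extras["full"].count("sqlalchemy") == 1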
@@ -105,7 +107,7 @@ def create(*, parameter_space=parameter_space, optimization_targets=optimization_targets, space_adapter=space_adapter, - **optimizer_kwargs + **optimizer_kwargs, ) return optimizer diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py index d4b7294f32..1a4fea7188 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/__init__.py @@ -10,6 +10,6 @@ from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer __all__ = [ - 'BaseBayesianOptimizer', - 'SmacOptimizer', + "BaseBayesianOptimizer", + "SmacOptimizer", ] diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index de333be46e..11669d4d79 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -17,8 +17,9 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): """Abstract base class defining the interface for Bayesian optimization.""" @abstractmethod - def surrogate_predict(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def surrogate_predict( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: """ Obtain a prediction from this Bayesian optimizer's surrogate model for the given configuration(s). @@ -26,16 +27,18 @@ def surrogate_predict(self, *, configs: pd.DataFrame, Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def acquisition_function(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def acquisition_function( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: """ Invokes the acquisition function from this Bayesian optimizer for the given configuration. @@ -43,9 +46,10 @@ def acquisition_function(self, *, configs: pd.DataFrame, Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. context : pd.DataFrame Not Yet Implemented. 
""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index e86d868cdb..7833ab31eb 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -28,19 +28,22 @@ class SmacOptimizer(BaseBayesianOptimizer): """Wrapper class for SMAC based Bayesian optimization.""" - def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - seed: Optional[int] = 0, - run_name: Optional[str] = None, - output_directory: Optional[str] = None, - max_trials: int = 100, - n_random_init: Optional[int] = None, - max_ratio: Optional[float] = None, - use_default_config: bool = False, - n_random_probability: float = 0.1): + def __init__( + self, + *, # pylint: disable=too-many-locals,too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + seed: Optional[int] = 0, + run_name: Optional[str] = None, + output_directory: Optional[str] = None, + max_trials: int = 100, + n_random_init: Optional[int] = None, + max_ratio: Optional[float] = None, + use_default_config: bool = False, + n_random_probability: float = 0.1, + ): """ Instantiate a new SMAC optimizer wrapper. @@ -60,18 +63,21 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments seed : Optional[int] By default SMAC uses a known seed (0) to keep results reproducible. - However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC. + However, if a `None` seed is explicitly provided, we let a random seed + be produced by SMAC. run_name : Optional[str] Name of this run. This is used to easily distinguish across different runs. If set to `None` (default), SMAC will generate a hash from metadata. output_directory : Optional[str] - The directory where SMAC output will saved. If set to `None` (default), a temporary dir will be used. + The directory where SMAC output will saved. If set to `None` (default), + a temporary dir will be used. max_trials : int Maximum number of trials (i.e., function evaluations) to be run. Defaults to 100. - Note that modifying this value directly affects the value of `n_random_init`, if latter is set to `None`. + Note that modifying this value directly affects the value of + `n_random_init`, if latter is set to `None`. n_random_init : Optional[int] Number of points evaluated at start to bootstrap the optimizer. @@ -115,7 +121,8 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments self.trial_info_map: Dict[ConfigSpace.Configuration, TrialInfo] = {} # The default when not specified is to use a known seed (0) to keep results reproducible. - # However, if a `None` seed is explicitly provided, we let a random seed be produced by SMAC. + # However, if a `None` seed is explicitly provided, we let a random seed be + # produced by SMAC. 
# https://automl.github.io/SMAC3/main/api/smac.scenario.html#smac.scenario.Scenario seed = -1 if seed is None else seed @@ -123,7 +130,9 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments if output_directory is None: # pylint: disable=consider-using-with try: - self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True) # Argument added in Python 3.10 + self._temp_output_directory = TemporaryDirectory( + ignore_cleanup_errors=True + ) # Argument added in Python 3.10 except TypeError: self._temp_output_directory = TemporaryDirectory() output_directory = self._temp_output_directory.name @@ -145,8 +154,12 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments seed=seed or -1, # if -1, SMAC will generate a random seed internally n_workers=1, # Use a single thread for evaluating trials ) - intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier(scenario, max_config_calls=1) - config_selector: ConfigSelector = Optimizer_Smac.get_config_selector(scenario, retrain_after=1) + intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier( + scenario, max_config_calls=1 + ) + config_selector: ConfigSelector = Optimizer_Smac.get_config_selector( + scenario, retrain_after=1 + ) # TODO: When bulk registering prior configs to rewarm the optimizer, # there is a way to inform SMAC's initial design that we have @@ -157,39 +170,45 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments # See Also: #488 initial_design_args: Dict[str, Union[list, int, float, Scenario]] = { - 'scenario': scenario, + "scenario": scenario, # Workaround a bug in SMAC that sets a default arg to a mutable # value that can cause issues when multiple optimizers are # instantiated with the use_default_config option within the same # process that use different ConfigSpaces so that the second # receives the default config from both as an additional config. - 'additional_configs': [] + "additional_configs": [], } if n_random_init is not None: - initial_design_args['n_configs'] = n_random_init + initial_design_args["n_configs"] = n_random_init if n_random_init > 0.25 * max_trials and max_ratio is None: warning( - 'Number of random initial configs (%d) is ' + - 'greater than 25%% of max_trials (%d). ' + - 'Consider setting max_ratio to avoid SMAC overriding n_random_init.', + "Number of random initial configs (%d) is " + + "greater than 25%% of max_trials (%d). " + + "Consider setting max_ratio to avoid SMAC overriding n_random_init.", n_random_init, max_trials, ) if max_ratio is not None: assert isinstance(max_ratio, float) and 0.0 <= max_ratio <= 1.0 - initial_design_args['max_ratio'] = max_ratio + initial_design_args["max_ratio"] = max_ratio # Use the default InitialDesign from SMAC. # (currently SBOL instead of LatinHypercube due to better uniformity # for initial sampling which results in lower overall samples required) - initial_design = Optimizer_Smac.get_initial_design(**initial_design_args) # type: ignore[arg-type] - # initial_design = LatinHypercubeInitialDesign(**initial_design_args) # type: ignore[arg-type] + initial_design = Optimizer_Smac.get_initial_design( + **initial_design_args, # type: ignore[arg-type] + ) + # initial_design = LatinHypercubeInitialDesign( + # **initial_design_args, # type: ignore[arg-type] + # ) # Workaround a bug in SMAC that doesn't pass the seed to the random # design when generated a random_design for itself via the # get_random_design static method when random_design is None. 
assert isinstance(n_random_probability, float) and n_random_probability >= 0 - random_design = ProbabilityRandomDesign(probability=n_random_probability, seed=scenario.seed) + random_design = ProbabilityRandomDesign( + probability=n_random_probability, seed=scenario.seed + ) self.base_optimizer = Optimizer_Smac( scenario, @@ -199,7 +218,8 @@ def __init__(self, *, # pylint: disable=too-many-locals,too-many-arguments random_design=random_design, config_selector=config_selector, multi_objective_algorithm=Optimizer_Smac.get_multi_objective_algorithm( - scenario, objective_weights=self._objective_weights), + scenario, objective_weights=self._objective_weights + ), overwrite=True, logging_level=False, # Use the existing logger ) @@ -214,7 +234,8 @@ def n_random_init(self) -> int: Gets the number of random samples to use to initialize the optimizer's search space sampling. - Note: This may not be equal to the value passed to the initializer, due to logic present in the SMAC. + Note: This may not be equal to the value passed to the initializer, due to + logic present in the SMAC. See Also: max_ratio Returns @@ -240,22 +261,31 @@ def _dummy_target_func(config: ConfigSpace.Configuration, seed: int = 0) -> None seed : int Random seed to use for the target function. Not actually used. """ - # NOTE: Providing a target function when using the ask-and-tell interface is an imperfection of the API - # -- this planned to be fixed in some future release: https://github.com/automl/SMAC3/issues/946 - raise RuntimeError('This function should never be called.') - - def _register(self, *, configs: pd.DataFrame, - scores: pd.DataFrame, context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + # NOTE: Providing a target function when using the ask-and-tell interface is + # an imperfection of the API -- this is planned to be fixed in some future + # release: https://github.com/automl/SMAC3/issues/946 + raise RuntimeError("This function should never be called.") + + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """ Registers the given configs and scores. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. scores : pd.DataFrame - Scores from running the configs. The index is the same as the index of the configs. + Scores from running the configs. The index is the same as the index of + the configs. context : pd.DataFrame Not Yet Implemented. 
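For orientation while reading the reformatted signature above, a minimal construction sketch follows. The keyword names come from the `__init__` shown in this hunk and the import path mirrors the file path in the diff; the one-parameter space and the concrete values (seed, trial counts, max_ratio) are illustrative assumptions, not part of this patch.

import ConfigSpace as CS

from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer

# Toy search space (compare the conftest.py fixture later in this patch).
space = CS.ConfigurationSpace(seed=1234)
space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1))

optimizer = SmacOptimizer(
    parameter_space=space,
    optimization_targets=["score"],  # column expected in the scores DataFrame on register()
    seed=0,            # an explicit None would let SMAC pick a random seed (see above)
    max_trials=100,
    n_random_init=10,
    max_ratio=0.25,    # keeps SMAC from overriding n_random_init (see the warning above)
)
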
@@ -273,17 +303,23 @@ def _register(self, *, configs: pd.DataFrame, warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) # Register each trial (one-by-one) - for (config, (_i, score)) in zip(self._to_configspace_configs(configs=configs), scores.iterrows()): - # Retrieve previously generated TrialInfo (returned by .ask()) or create new TrialInfo instance + for config, (_i, score) in zip( + self._to_configspace_configs(configs=configs), scores.iterrows() + ): + # Retrieve previously generated TrialInfo (returned by .ask()) or create + # new TrialInfo instance info: TrialInfo = self.trial_info_map.get( - config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed)) + config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) + ) value = TrialValue(cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS) self.base_optimizer.tell(info, value, save=False) # Save optimizer once we register all configs self.base_optimizer.optimizer.save() - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. @@ -301,9 +337,8 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr Not yet implemented. """ if TYPE_CHECKING: - from smac.runhistory import ( - TrialInfo, # pylint: disable=import-outside-toplevel,unused-import - ) + # pylint: disable=import-outside-toplevel,unused-import + from smac.runhistory import TrialInfo if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) @@ -313,18 +348,25 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr self.optimizer_parameter_space.check_configuration(trial.config) assert trial.config.config_space == self.optimizer_parameter_space self.trial_info_map[trial.config] = trial - config_df = pd.DataFrame([trial.config], columns=list(self.optimizer_parameter_space.keys())) + config_df = pd.DataFrame( + [trial.config], columns=list(self.optimizer_parameter_space.keys()) + ) return config_df, None - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() - def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: - from smac.utils.configspace import ( - convert_configurations_to_array, # pylint: disable=import-outside-toplevel - ) + def surrogate_predict( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: + # pylint: disable=import-outside-toplevel + from smac.utils.configspace import convert_configurations_to_array if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) @@ -334,16 +376,24 @@ def surrogate_predict(self, *, configs: pd.DataFrame, context: Optional[pd.DataF # pylint: disable=protected-access if len(self._observations) <= self.base_optimizer._initial_design._n_configs: raise RuntimeError( - 'Surrogate model can make predictions *only* after all initial points have been evaluated ' + - f'{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}') + 
"Surrogate model can make predictions *only* after " + "all initial points have been evaluated " + f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}" + ) if self.base_optimizer._config_selector._model is None: - raise RuntimeError('Surrogate model is not yet trained') + raise RuntimeError("Surrogate model is not yet trained") - config_array: npt.NDArray = convert_configurations_to_array(self._to_configspace_configs(configs=configs)) + config_array: npt.NDArray = convert_configurations_to_array( + self._to_configspace_configs(configs=configs) + ) mean_predictions, _ = self.base_optimizer._config_selector._model.predict(config_array) - return mean_predictions.reshape(-1,) + return mean_predictions.reshape( + -1, + ) - def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None) -> npt.NDArray: + def acquisition_function( + self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + ) -> npt.NDArray: if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) if self._space_adapter: @@ -351,13 +401,15 @@ def acquisition_function(self, *, configs: pd.DataFrame, context: Optional[pd.Da # pylint: disable=protected-access if self.base_optimizer._config_selector._acquisition_function is None: - raise RuntimeError('Acquisition function is not yet initialized') + raise RuntimeError("Acquisition function is not yet initialized") cs_configs: list = self._to_configspace_configs(configs=configs) - return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape(-1,) + return self.base_optimizer._config_selector._acquisition_function(cs_configs).reshape( + -1, + ) def cleanup(self) -> None: - if hasattr(self, '_temp_output_directory') and self._temp_output_directory is not None: + if hasattr(self, "_temp_output_directory") and self._temp_output_directory is not None: self._temp_output_directory.cleanup() self._temp_output_directory = None @@ -368,7 +420,8 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. Returns ------- @@ -377,5 +430,5 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. """ return [ ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict()) - for (_, config) in configs.astype('O').iterrows() + for (_, config) in configs.astype("O").iterrows() ] diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index aaefdbdf3d..958f98e02e 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -26,16 +26,20 @@ class EvaluatedSample(NamedTuple): class FlamlOptimizer(BaseOptimizer): """Wrapper class for FLAML Optimizer: A fast library for AutoML and tuning.""" - # The name of an internal objective attribute that is calculated as a weighted average of the user provided objective metrics. + # The name of an internal objective attribute that is calculated as a weighted + # average of the user provided objective metrics. 
_METRIC_NAME = "FLAML_score" - def __init__(self, *, # pylint: disable=too-many-arguments - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None, - low_cost_partial_config: Optional[dict] = None, - seed: Optional[int] = None): + def __init__( + self, + *, # pylint: disable=too-many-arguments + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + low_cost_partial_config: Optional[dict] = None, + seed: Optional[int] = None, + ): """ Create an MLOS wrapper for FLAML. @@ -55,10 +59,12 @@ def __init__(self, *, # pylint: disable=too-many-arguments low_cost_partial_config : dict A dictionary from a subset of controlled dimensions to the initial low-cost values. - More info: https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune + More info: + https://microsoft.github.io/FLAML/docs/FAQ#about-low_cost_partial_config-in-tune seed : Optional[int] - If provided, calls np.random.seed() with the provided value to set the seed globally at init. + If provided, calls np.random.seed() with the provided value to set the + seed globally at init. """ super().__init__( parameter_space=parameter_space, @@ -78,21 +84,30 @@ def __init__(self, *, # pylint: disable=too-many-arguments configspace_to_flaml_space, ) - self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space(self.optimizer_parameter_space) + self.flaml_parameter_space: Dict[str, FlamlDomain] = configspace_to_flaml_space( + self.optimizer_parameter_space + ) self.low_cost_partial_config = low_cost_partial_config self.evaluated_samples: Dict[ConfigSpace.Configuration, EvaluatedSample] = {} self._suggested_config: Optional[dict] - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """ Registers the given configs and scores. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. 
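The comment above describes how multiple optimization targets are collapsed into the single internal `FLAML_score` metric. A stand-alone sketch of that reduction is shown below; the column names and weights are made up for illustration, while the reduction itself matches the `_register` implementation further down in this file.

import numpy as np
import pandas as pd

# Hypothetical two-objective score row and matching objective weights.
score = pd.Series({"latency": 12.0, "cost": 3.5})
objective_weights = [0.7, 0.3]

# Same reduction as in FlamlOptimizer._register: a weighted average of the targets.
flaml_score = float(np.average(score.astype(float), weights=objective_weights))
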
@@ -108,9 +123,10 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, if metadata is not None: warn(f"Not Implemented: Ignoring metadata {list(metadata.columns)}", UserWarning) - for (_, config), (_, score) in zip(configs.astype('O').iterrows(), scores.iterrows()): + for (_, config), (_, score) in zip(configs.astype("O").iterrows(), scores.iterrows()): cs_config: ConfigSpace.Configuration = ConfigSpace.Configuration( - self.optimizer_parameter_space, values=config.to_dict()) + self.optimizer_parameter_space, values=config.to_dict() + ) if cs_config in self.evaluated_samples: warn(f"Configuration {config} was already registered", UserWarning) self.evaluated_samples[cs_config] = EvaluatedSample( @@ -118,7 +134,9 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, score=float(np.average(score.astype(float), weights=self._objective_weights)), ) - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. @@ -142,27 +160,35 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr config: dict = self._get_next_config() return pd.DataFrame(config, index=[0]), None - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() def _target_function(self, config: dict) -> Union[dict, None]: """ Configuration evaluation function called by FLAML optimizer. - FLAML may suggest the same configuration multiple times (due to its warm-start mechanism). - Once FLAML suggests an unseen configuration, we store it, and stop the optimization process. + FLAML may suggest the same configuration multiple times (due to its + warm-start mechanism). Once FLAML suggests an unseen configuration, we + store it, and stop the optimization process. Parameters ---------- config: dict Next configuration to be evaluated, as suggested by FLAML. - This config is stored internally and is returned to user, via `.suggest()` method. + This config is stored internally and is returned to user, via + `.suggest()` method. Returns ------- result: Union[dict, None] - Dictionary with a single key, `FLAML_score`, if config already evaluated; `None` otherwise. + Dictionary with a single key, `FLAML_score`, if config already + evaluated; `None` otherwise. """ cs_config = normalize_config(self.optimizer_parameter_space, config) if cs_config in self.evaluated_samples: @@ -176,10 +202,13 @@ def _get_next_config(self) -> dict: Warm-starts a new instance of FLAML, and returns a recommended, unseen new configuration. - Since FLAML does not provide an ask-and-tell interface, we need to create a new instance of FLAML - each time we get asked for a new suggestion. This is suboptimal performance-wise, but works. - To do so, we use any previously evaluated configs to bootstrap FLAML (i.e., warm-start). - For more info: https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start + Since FLAML does not provide an ask-and-tell interface, we need to create a + new instance of FLAML each time we get asked for a new suggestion. This is + suboptimal performance-wise, but works. 
+ To do so, we use any previously evaluated configs to bootstrap FLAML (i.e., + warm-start). + For more info: + https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function#warm-start Returns ------- @@ -201,16 +230,14 @@ def _get_next_config(self) -> dict: dict(normalize_config(self.optimizer_parameter_space, conf)) for conf in self.evaluated_samples ] - evaluated_rewards = [ - s.score for s in self.evaluated_samples.values() - ] + evaluated_rewards = [s.score for s in self.evaluated_samples.values()] # Warm start FLAML optimizer self._suggested_config = None tune.run( self._target_function, config=self.flaml_parameter_space, - mode='min', + mode="min", metric=self._METRIC_NAME, points_to_evaluate=points_to_evaluate, evaluated_rewards=evaluated_rewards, @@ -219,6 +246,6 @@ def _get_next_config(self) -> dict: verbose=0, ) if self._suggested_config is None: - raise RuntimeError('FLAML did not produce a suggestion') + raise RuntimeError("FLAML did not produce a suggestion") return self._suggested_config # type: ignore[unreachable] diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index d9b37910b5..ddd4a466db 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -20,11 +20,14 @@ class BaseOptimizer(metaclass=ABCMeta): """Optimizer abstract base class defining the basic interface.""" - def __init__(self, *, - parameter_space: ConfigSpace.ConfigurationSpace, - optimization_targets: List[str], - objective_weights: Optional[List[float]] = None, - space_adapter: Optional[BaseSpaceAdapter] = None): + def __init__( + self, + *, + parameter_space: ConfigSpace.ConfigurationSpace, + optimization_targets: List[str], + objective_weights: Optional[List[float]] = None, + space_adapter: Optional[BaseSpaceAdapter] = None, + ): """ Create a new instance of the base optimizer. @@ -40,8 +43,9 @@ def __init__(self, *, The space adapter class to employ for parameter space transformations. """ self.parameter_space: ConfigSpace.ConfigurationSpace = parameter_space - self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = \ + self.optimizer_parameter_space: ConfigSpace.ConfigurationSpace = ( parameter_space if space_adapter is None else space_adapter.target_parameter_space + ) if space_adapter is not None and space_adapter.orig_parameter_space != parameter_space: raise ValueError("Given parameter space differs from the one given to space adapter") @@ -64,8 +68,14 @@ def space_adapter(self) -> Optional[BaseSpaceAdapter]: """Get the space adapter instance (if any).""" return self._space_adapter - def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """ Wrapper method, which employs the space adapter (if any), before registering the configs and scores. @@ -73,7 +83,8 @@ def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. 
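Since this hunk and the next reshape the public `suggest()`/`register()` pair, a schematic ask-and-tell loop may help keep the keyword-only interface in view. Here `optimizer` is any concrete `BaseOptimizer` subclass, `evaluate` is a hypothetical user-supplied benchmark function, and the "score" target column is an assumption; none of these are part of this diff.

import pandas as pd

for _ in range(10):
    suggestion, metadata = optimizer.suggest()   # one-row DataFrame of configs
    score = evaluate(suggestion)                 # hypothetical objective, returns a float
    optimizer.register(
        configs=suggestion,
        scores=pd.DataFrame({"score": [score]}),
        metadata=metadata,
    )
best_configs, best_scores, _context = optimizer.get_best_observations(n_max=1)
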
@@ -85,46 +96,56 @@ def register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, """ # Do some input validation. assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(scores.columns) == set(self._optimization_targets), \ - "Mismatched optimization targets." - assert self._has_context is None or self._has_context ^ (context is None), \ - "Context must always be added or never be added." - assert len(configs) == len(scores), \ - "Mismatched number of configs and scores." + assert set(scores.columns) == set( + self._optimization_targets + ), "Mismatched optimization targets." + assert self._has_context is None or self._has_context ^ ( + context is None + ), "Context must always be added or never be added." + assert len(configs) == len(scores), "Mismatched number of configs and scores." if context is not None: - assert len(configs) == len(context), \ - "Mismatched number of configs and context." - assert configs.shape[1] == len(self.parameter_space.values()), \ - "Mismatched configuration shape." + assert len(configs) == len(context), "Mismatched number of configs and context." + assert configs.shape[1] == len( + self.parameter_space.values() + ), "Mismatched configuration shape." self._observations.append((configs, scores, context)) self._has_context = context is not None if self._space_adapter: configs = self._space_adapter.inverse_transform(configs) - assert configs.shape[1] == len(self.optimizer_parameter_space.values()), \ - "Mismatched configuration shape after inverse transform." + assert configs.shape[1] == len( + self.optimizer_parameter_space.values() + ), "Mismatched configuration shape after inverse transform." return self._register(configs=configs, scores=scores, context=context) @abstractmethod - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """ Registers the given configs and scores. Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. context : pd.DataFrame Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover - def suggest(self, *, context: Optional[pd.DataFrame] = None, - defaults: bool = False) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def suggest( + self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Wrapper method, which employs the space adapter (if any), after suggesting a new configuration. @@ -149,18 +170,23 @@ def suggest(self, *, context: Optional[pd.DataFrame] = None, configuration = self.space_adapter.inverse_transform(configuration) else: configuration, metadata = self._suggest(context=context) - assert len(configuration) == 1, \ - "Suggest must return a single configuration." - assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), \ - "Optimizer suggested a configuration that does not match the expected parameter space." 
+ assert len(configuration) == 1, "Suggest must return a single configuration." + assert set(configuration.columns).issubset(set(self.optimizer_parameter_space)), ( + "Optimizer suggested a configuration that does " + "not match the expected parameter space." + ) if self._space_adapter: configuration = self._space_adapter.transform(configuration) - assert set(configuration.columns).issubset(set(self.parameter_space)), \ - "Space adapter produced a configuration that does not match the expected parameter space." + assert set(configuration.columns).issubset(set(self.parameter_space)), ( + "Space adapter produced a configuration that does " + "not match the expected parameter space." + ) return configuration, metadata @abstractmethod - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. @@ -177,12 +203,16 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr metadata : Optional[pd.DataFrame] The metadata associated with the given configuration used for evaluations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, - metadata: Optional[pd.DataFrame] = None) -> None: + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """ Registers the given configs as "pending". That is it say, it has been suggested by the optimizer, and an experiment trial has been started. This can be useful @@ -191,13 +221,14 @@ def register_pending(self, *, configs: pd.DataFrame, Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. context : pd.DataFrame Not Yet Implemented. metadata : Optional[pd.DataFrame] Not Yet Implemented. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ @@ -212,11 +243,17 @@ def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.Data raise ValueError("No observations registered yet.") configs = pd.concat([config for config, _, _ in self._observations]).reset_index(drop=True) scores = pd.concat([score for _, score, _ in self._observations]).reset_index(drop=True) - contexts = pd.concat([pd.DataFrame() if context is None else context - for _, _, context in self._observations]).reset_index(drop=True) + contexts = pd.concat( + [ + pd.DataFrame() if context is None else context + for _, _, context in self._observations + ] + ).reset_index(drop=True) return (configs, scores, contexts if len(contexts.columns) > 0 else None) - def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: + def get_best_observations( + self, *, n_max: int = 1 + ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Get the N best observations so far as a triplet of DataFrames (config, score, context). Default is N=1. 
The columns are ordered in ASCENDING order of the @@ -237,8 +274,7 @@ def get_best_observations(self, *, n_max: int = 1) -> Tuple[pd.DataFrame, pd.Dat raise ValueError("No observations registered yet.") (configs, scores, contexts) = self.get_observations() idx = scores.nsmallest(n_max, columns=self._optimization_targets, keep="first").index - return (configs.loc[idx], scores.loc[idx], - None if contexts is None else contexts.loc[idx]) + return (configs.loc[idx], scores.loc[idx], None if contexts is None else contexts.loc[idx]) def cleanup(self) -> None: """ @@ -257,7 +293,7 @@ def _from_1hot(self, *, config: npt.NDArray) -> pd.DataFrame: j = 0 for param in self.optimizer_parameter_space.values(): if isinstance(param, ConfigSpace.CategoricalHyperparameter): - for (offset, val) in enumerate(param.choices): + for offset, val in enumerate(param.choices): if config[i][j + offset] == 1: df_dict[param.name].append(val) break diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index 7f83b8e086..ddee68f345 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -23,8 +23,14 @@ class RandomOptimizer(BaseOptimizer): The parameter space to optimize. """ - def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + def _register( + self, + *, + configs: pd.DataFrame, + scores: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: """ Registers the given configs and scores. @@ -33,7 +39,8 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, Parameters ---------- configs : pd.DataFrame - Dataframe of configs / parameters. The columns are parameter names and the rows are the configs. + Dataframe of configs / parameters. The columns are parameter names and + the rows are the configs. scores : pd.DataFrame Scores from running the configs. The index is the same as the index of the configs. @@ -50,7 +57,9 @@ def _register(self, *, configs: pd.DataFrame, scores: pd.DataFrame, warn(f"Not Implemented: Ignoring context {list(metadata.columns)}", UserWarning) # should we pop them from self.pending_observations? - def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: + def _suggest( + self, *, context: Optional[pd.DataFrame] = None + ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. @@ -72,9 +81,17 @@ def _suggest(self, *, context: Optional[pd.DataFrame] = None) -> Tuple[pd.DataFr if context is not None: # not sure how that works here? 
warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) - return pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), None - - def register_pending(self, *, configs: pd.DataFrame, - context: Optional[pd.DataFrame] = None, metadata: Optional[pd.DataFrame] = None) -> None: + return ( + pd.DataFrame(dict(self.optimizer_parameter_space.sample_configuration()), index=[0]), + None, + ) + + def register_pending( + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, + metadata: Optional[pd.DataFrame] = None, + ) -> None: raise NotImplementedError() # self._pending_observations.append((configs, context)) diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 8618707f9a..3187e32bc6 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -13,8 +13,8 @@ from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter __all__ = [ - 'IdentityAdapter', - 'LlamaTuneAdapter', + "IdentityAdapter", + "LlamaTuneAdapter", ] @@ -30,10 +30,13 @@ class SpaceAdapterType(Enum): # To make mypy happy, we need to define a type variable for each optimizer type. # https://github.com/python/mypy/issues/12952 -# ConcreteSpaceAdapter = TypeVar('ConcreteSpaceAdapter', *[member.value for member in SpaceAdapterType]) +# ConcreteSpaceAdapter = TypeVar( +# "ConcreteSpaceAdapter", +# *[member.value for member in SpaceAdapterType], +# ) # To address this, we add a test for complete coverage of the enum. ConcreteSpaceAdapter = TypeVar( - 'ConcreteSpaceAdapter', + "ConcreteSpaceAdapter", IdentityAdapter, LlamaTuneAdapter, ) @@ -45,10 +48,12 @@ class SpaceAdapterFactory: # pylint: disable=too-few-public-methods @staticmethod - def create(*, - parameter_space: ConfigSpace.ConfigurationSpace, - space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, - space_adapter_kwargs: Optional[dict] = None) -> ConcreteSpaceAdapter: # type: ignore[type-var] + def create( + *, + parameter_space: ConfigSpace.ConfigurationSpace, + space_adapter_type: SpaceAdapterType = SpaceAdapterType.IDENTITY, + space_adapter_kwargs: Optional[dict] = None, + ) -> ConcreteSpaceAdapter: # type: ignore[type-var] """ Create a new space adapter instance, given the parameter space and potential space adapter options. @@ -73,8 +78,7 @@ def create(*, space_adapter_kwargs = {} space_adapter: ConcreteSpaceAdapter = space_adapter_type.value( - orig_parameter_space=parameter_space, - **space_adapter_kwargs + orig_parameter_space=parameter_space, **space_adapter_kwargs ) return space_adapter diff --git a/mlos_core/mlos_core/spaces/adapters/adapter.py b/mlos_core/mlos_core/spaces/adapters/adapter.py index f28ab694a4..2d48a14c31 100644 --- a/mlos_core/mlos_core/spaces/adapters/adapter.py +++ b/mlos_core/mlos_core/spaces/adapters/adapter.py @@ -41,7 +41,7 @@ def orig_parameter_space(self) -> ConfigSpace.ConfigurationSpace: @abstractmethod def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: """Target parameter space that is fed to the underlying optimizer.""" - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: @@ -53,7 +53,8 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: Parameters ---------- configuration : pd.DataFrame - Pandas dataframe with a single row. 
Column names are the parameter names of the target parameter space. + Pandas dataframe with a single row. Column names are the parameter names + of the target parameter space. Returns ------- @@ -61,7 +62,7 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: Pandas dataframe with a single row, containing the translated configuration. Column names are the parameter names of the original parameter space. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover @abstractmethod def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: @@ -75,12 +76,14 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: ---------- configurations : pd.DataFrame Dataframe of configurations / parameters, which belong to the original parameter space. - The columns are the parameter names the original parameter space and the rows are the configurations. + The columns are the parameter names the original parameter space and the + rows are the configurations. Returns ------- configurations : pd.DataFrame Dataframe of the translated configurations / parameters. - The columns are the parameter names of the target parameter space and the rows are the configurations. + The columns are the parameter names of the target parameter space and + the rows are the configurations. """ - pass # pylint: disable=unnecessary-pass # pragma: no cover + pass # pylint: disable=unnecessary-pass # pragma: no cover diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index 8a416d40ab..e304c0dd50 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -17,7 +17,7 @@ from mlos_core.util import normalize_config -class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes +class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance-attributes """Implementation of LlamaTune, a set of parameter space transformation techniques, aimed at improving the sample-efficiency of the underlying optimizer. """ @@ -27,7 +27,7 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance- HeSBO projection. """ - DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = .2 + DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE = 0.2 """Default percentage of bias for each special parameter value.""" DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM = 10000 @@ -35,12 +35,15 @@ class LlamaTuneAdapter(BaseSpaceAdapter): # pylint: disable=too-many-instance- discretization is used. """ - def __init__(self, *, - orig_parameter_space: ConfigSpace.ConfigurationSpace, - num_low_dims: int = DEFAULT_NUM_LOW_DIMS, - special_param_values: Optional[dict] = None, - max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, - use_approximate_reverse_mapping: bool = False): + def __init__( + self, + *, + orig_parameter_space: ConfigSpace.ConfigurationSpace, + num_low_dims: int = DEFAULT_NUM_LOW_DIMS, + special_param_values: Optional[dict] = None, + max_unique_values_per_param: Optional[int] = DEFAULT_MAX_UNIQUE_VALUES_PER_PARAM, + use_approximate_reverse_mapping: bool = False, + ): """ Create a space adapter that employs LlamaTune's techniques. 
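To make the reformatted constructor arguments above concrete, a small construction sketch follows. The keyword names and the synthetic `dim_*` target-space parameters are taken from this file; the 16-parameter toy space and the specific values are illustrative assumptions only.

import ConfigSpace as CS

from mlos_core.spaces.adapters.llamatune import LlamaTuneAdapter

# A toy high-dimensional space that LlamaTune projects down to `num_low_dims`.
big_space = CS.ConfigurationSpace(seed=1234)
for i in range(16):
    big_space.add_hyperparameter(
        CS.UniformIntegerHyperparameter(name=f"p{i}", lower=0, upper=100)
    )

adapter = LlamaTuneAdapter(
    orig_parameter_space=big_space,
    num_low_dims=2,                         # must be smaller than the original space
    max_unique_values_per_param=10000,      # discretization of the low-dim space
    use_approximate_reverse_mapping=False,  # see _try_generate_approx_inverse_mapping below
)
# adapter.target_parameter_space now holds the synthetic "dim_*" parameters; the underlying
# optimizer searches that space, and transform()/inverse_transform() map configurations
# back and forth between it and the original space.
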
@@ -59,7 +62,10 @@ def __init__(self, *, super().__init__(orig_parameter_space=orig_parameter_space) if num_low_dims >= len(orig_parameter_space): - raise ValueError("Number of target config space dimensions should be less than those of original config space.") + raise ValueError( + "Number of target config space dimensions should be " + "less than those of original config space." + ) # Validate input special param values dict special_param_values = special_param_values or {} @@ -91,26 +97,35 @@ def target_parameter_space(self) -> ConfigSpace.ConfigurationSpace: def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: target_configurations = [] - for (_, config) in configurations.astype('O').iterrows(): + for _, config in configurations.astype("O").iterrows(): configuration = ConfigSpace.Configuration( - self.orig_parameter_space, values=config.to_dict()) + self.orig_parameter_space, values=config.to_dict() + ) target_config = self._suggested_configs.get(configuration, None) - # NOTE: HeSBO is a non-linear projection method, and does not inherently support inverse projection - # To (partly) support this operation, we keep track of the suggested low-dim point(s) along with the - # respective high-dim point; this way we can retrieve the low-dim point, from its high-dim counterpart. + # NOTE: HeSBO is a non-linear projection method, and does not inherently + # support inverse projection + # To (partly) support this operation, we keep track of the suggested + # low-dim point(s) along with the respective high-dim point; this way we + # can retrieve the low-dim point, from its high-dim counterpart. if target_config is None: - # Inherently it is not supported to register points, which were not suggested by the optimizer. + # Inherently it is not supported to register points, which were not + # suggested by the optimizer. if configuration == self.orig_parameter_space.get_default_configuration(): # Default configuration should always be registerable. pass elif not self._use_approximate_reverse_mapping: - raise ValueError(f"{repr(configuration)}\n" "The above configuration was not suggested by the optimizer. " - "Approximate reverse mapping is currently disabled; thus *only* configurations suggested " - "previously by the optimizer can be registered.") - - # ...yet, we try to support that by implementing an approximate reverse mapping using pseudo-inverse matrix. - if getattr(self, '_pinv_matrix', None) is None: + raise ValueError( + f"{repr(configuration)}\n" + "The above configuration was not suggested by the optimizer. " + "Approximate reverse mapping is currently disabled; " + "thus *only* configurations suggested " + "previously by the optimizer can be registered." + ) + + # ...yet, we try to support that by implementing an approximate + # reverse mapping using pseudo-inverse matrix. 
+ if getattr(self, "_pinv_matrix", None) is None: self._try_generate_approx_inverse_mapping() # Replace NaNs with zeros for inactive hyperparameters @@ -119,19 +134,27 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: # NOTE: applying special value biasing is not possible vector = self._config_scaler.inverse_transform([config_vector])[0] target_config_vector = self._pinv_matrix.dot(vector) - target_config = ConfigSpace.Configuration(self.target_parameter_space, vector=target_config_vector) + target_config = ConfigSpace.Configuration( + self.target_parameter_space, vector=target_config_vector + ) target_configurations.append(target_config) - return pd.DataFrame(target_configurations, columns=list(self.target_parameter_space.keys())) + return pd.DataFrame( + target_configurations, columns=list(self.target_parameter_space.keys()) + ) def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: if len(configuration) != 1: - raise ValueError("Configuration dataframe must contain exactly 1 row. " - f"Found {len(configuration)} rows.") + raise ValueError( + "Configuration dataframe must contain exactly 1 row. " + f"Found {len(configuration)} rows." + ) target_values_dict = configuration.iloc[0].to_dict() - target_configuration = ConfigSpace.Configuration(self.target_parameter_space, values=target_values_dict) + target_configuration = ConfigSpace.Configuration( + self.target_parameter_space, values=target_values_dict + ) orig_values_dict = self._transform(target_values_dict) orig_configuration = normalize_config(self.orig_parameter_space, orig_values_dict) @@ -139,9 +162,13 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: # Add to inverse dictionary -- needed for registering the performance later self._suggested_configs[orig_configuration] = target_configuration - return pd.DataFrame([list(orig_configuration.values())], columns=list(orig_configuration.keys())) + return pd.DataFrame( + [list(orig_configuration.values())], columns=list(orig_configuration.keys()) + ) - def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_param: Optional[int]) -> None: + def _construct_low_dim_space( + self, num_low_dims: int, max_unique_values_per_param: Optional[int] + ) -> None: """ Constructs the low-dimensional parameter (potentially discretized) search space. @@ -158,19 +185,25 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para q_scaler = None if max_unique_values_per_param is None: hyperparameters = [ - ConfigSpace.UniformFloatHyperparameter(name=f'dim_{idx}', lower=-1, upper=1) + ConfigSpace.UniformFloatHyperparameter(name=f"dim_{idx}", lower=-1, upper=1) for idx in range(num_low_dims) ] else: - # Currently supported optimizers do not support defining a discretized space (like ConfigSpace does using `q` kwarg). - # Thus, to support space discretization, we define the low-dimensional space using integer hyperparameters. - # We also employ a scaler, which scales suggested values to [-1, 1] range, used by HeSBO projection. + # Currently supported optimizers do not support defining a discretized + # space (like ConfigSpace does using `q` kwarg). + # Thus, to support space discretization, we define the low-dimensional + # space using integer hyperparameters. + # We also employ a scaler, which scales suggested values to [-1, 1] + # range, used by HeSBO projection. 
hyperparameters = [ - ConfigSpace.UniformIntegerHyperparameter(name=f'dim_{idx}', lower=1, upper=max_unique_values_per_param) + ConfigSpace.UniformIntegerHyperparameter( + name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param + ) for idx in range(num_low_dims) ] - # Initialize quantized values scaler: from [0, max_unique_values_per_param] to (-1, 1) range + # Initialize quantized values scaler: + # from [0, max_unique_values_per_param] to (-1, 1) range q_scaler = MinMaxScaler(feature_range=(-1, 1)) ones_vector = np.ones(num_low_dims) max_value_vector = ones_vector * max_unique_values_per_param @@ -180,7 +213,9 @@ def _construct_low_dim_space(self, num_low_dims: int, max_unique_values_per_para # Construct low-dimensional parameter search space config_space = ConfigSpace.ConfigurationSpace(name=self.orig_parameter_space.name) - config_space.random = self._random_state # use same random state as in original parameter space + config_space.random = ( + self._random_state + ) # use same random state as in original parameter space config_space.add_hyperparameters(hyperparameters) self._target_config_space = config_space @@ -220,10 +255,10 @@ def _transform(self, configuration: dict) -> dict: # Clip value to force it to fall in [0, 1] # NOTE: HeSBO projection ensures that theoretically but due to # floating point ops nuances this is not always guaranteed - value = max(0., min(1., norm_value)) # pylint: disable=redefined-loop-name + value = max(0.0, min(1.0, norm_value)) # pylint: disable=redefined-loop-name if isinstance(param, ConfigSpace.CategoricalHyperparameter): - index = int(value * len(param.choices)) # truncate integer part + index = int(value * len(param.choices)) # truncate integer part index = max(0, min(len(param.choices) - 1, index)) # NOTE: potential rounding here would be unfair to first & last values orig_value = param.choices[index] @@ -231,16 +266,20 @@ def _transform(self, configuration: dict) -> dict: if param.name in self._special_param_values_dict: value = self._special_param_value_scaler(param, value) - orig_value = param._transform(value) # pylint: disable=protected-access + orig_value = param._transform(value) # pylint: disable=protected-access orig_value = max(param.lower, min(param.upper, orig_value)) else: - raise NotImplementedError("Only Categorical, Integer, and Float hyperparameters are currently supported.") + raise NotImplementedError( + "Only Categorical, Integer, and Float hyperparameters are currently supported." + ) original_config[param.name] = orig_value return original_config - def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float) -> float: + def _special_param_value_scaler( + self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float + ) -> float: """ Biases the special value(s) of this parameter, by shifting the normalized `input_value` towards those. @@ -261,7 +300,7 @@ def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperpara special_values_list = self._special_param_values_dict[param.name] # Check if input value corresponds to some special value - perc_sum = 0. 
+ perc_sum = 0.0 ret: float for special_value, biasing_perc in special_values_list: perc_sum += biasing_perc @@ -270,8 +309,10 @@ def _special_param_value_scaler(self, param: ConfigSpace.UniformIntegerHyperpara return ret # Scale input value uniformly to non-special values - ret = param._inverse_transform( # pylint: disable=protected-access - param._transform_scalar((input_value - perc_sum) / (1 - perc_sum))) # pylint: disable=protected-access + # pylint: disable=protected-access + ret = param._inverse_transform( + param._transform_scalar((input_value - perc_sum) / (1 - perc_sum)) + ) return ret # pylint: disable=too-complex,too-many-branches @@ -301,8 +342,10 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non hyperparameter = self.orig_parameter_space[param] if not isinstance(hyperparameter, ConfigSpace.UniformIntegerHyperparameter): - raise NotImplementedError(error_prefix + f"Parameter '{param}' is not supported. " - "Only Integer Hyperparameters are currently supported.") + raise NotImplementedError( + error_prefix + f"Parameter '{param}' is not supported. " + "Only Integer Hyperparameters are currently supported." + ) if isinstance(value, int): # User specifies a single special value -- default biasing percentage is used @@ -313,45 +356,78 @@ def _validate_special_param_values(self, special_param_values_dict: dict) -> Non elif isinstance(value, list) and value: if all(isinstance(t, int) for t in value): # User specifies list of special values - tuple_list = [(v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value] - elif all(isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value): - # User specifies list of tuples; each tuple defines the special value and the biasing percentage + tuple_list = [ + (v, self.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE) for v in value + ] + elif all( + isinstance(t, tuple) and [type(v) for v in t] == [int, float] for t in value + ): + # User specifies list of tuples; each tuple defines the special + # value and the biasing percentage tuple_list = value else: - raise ValueError(error_prefix + f"Invalid format in value list for parameter '{param}'. " - f"Special value list should contain either integers, or (special value, biasing %) tuples.") + raise ValueError( + error_prefix + f"Invalid format in value list for parameter '{param}'. " + f"Special value list should contain either integers, " + "or (special value, biasing %) tuples." + ) else: - raise ValueError(error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " - "an int, a (int, float) tuple, a list of integers, or a list of (int, float) tuples.") + raise ValueError( + error_prefix + f"Invalid format for parameter '{param}'. Dict value should be " + "an int, a (int, float) tuple, a list of integers, " + "or a list of (int, float) tuples." + ) # Are user-specified special values valid? if not all(hyperparameter.lower <= v <= hyperparameter.upper for v, _ in tuple_list): - raise ValueError(error_prefix + f"One (or more) special values are outside of parameter '{param}' value domain.") + raise ValueError( + error_prefix + + "One (or more) special values are outside of parameter " + + f"'{param}' value domain." + ) # Are user-provided special values unique? 
if len(set(v for v, _ in tuple_list)) != len(tuple_list): - raise ValueError(error_prefix + f"One (or more) special values are defined more than once for parameter '{param}'.") + raise ValueError( + error_prefix + + "One (or more) special values are defined more than once " + + f"for parameter '{param}'." + ) # Are biasing percentages valid? if not all(0 < perc < 1 for _, perc in tuple_list): - raise ValueError(error_prefix + f"One (or more) biasing percentages for parameter '{param}' are invalid: " - "i.e., fall outside (0, 1) range.") + raise ValueError( + error_prefix + + f"One (or more) biasing percentages for parameter '{param}' are invalid: " + "i.e., fall outside (0, 1) range." + ) total_percentage = sum(perc for _, perc in tuple_list) - if total_percentage >= 1.: - raise ValueError(error_prefix + f"Total special values percentage for parameter '{param}' surpass 100%.") + if total_percentage >= 1.0: + raise ValueError( + error_prefix + + f"Total special values percentage for parameter '{param}' surpass 100%." + ) # ... and reasonable? if total_percentage >= 0.5: - warn(f"Total special values percentage for parameter '{param}' exceeds 50%.", UserWarning) + warn( + f"Total special values percentage for parameter '{param}' exceeds 50%.", + UserWarning, + ) sanitized_dict[param] = tuple_list self._special_param_values_dict = sanitized_dict def _try_generate_approx_inverse_mapping(self) -> None: - """Tries to generate an approximate reverse mapping: i.e., from high-dimensional space to the low-dimensional one. - Reverse mapping is generated using the pseudo-inverse matrix, of original HeSBO projection matrix. - This mapping can be potentially used to register configurations that were *not* previously suggested by the optimizer. + """Tries to generate an approximate reverse mapping: + i.e., from high-dimensional space to the low-dimensional one. + + Reverse mapping is generated using the pseudo-inverse matrix, of original + HeSBO projection matrix. + This mapping can be potentially used to register configurations that were + *not* previously suggested by the optimizer. - NOTE: This method is experimental, and there is currently no guarantee that it works as expected. + NOTE: This method is experimental, and there is currently no guarantee that + it works as expected. Raises ------ @@ -362,9 +438,16 @@ def _try_generate_approx_inverse_mapping(self) -> None: pinv, ) - warn("Trying to register a configuration that was not previously suggested by the optimizer. " + - "This inverse configuration transformation is typically not supported. " + - "However, we will try to register this configuration using an *experimental* method.", UserWarning) + warn( + ( + "Trying to register a configuration that was not " + "previously suggested by the optimizer.\n" + "This inverse configuration transformation is typically not supported.\n" + "However, we will try to register this configuration " + "using an *experimental* method." 
+ ), + UserWarning, + ) orig_space_num_dims = len(list(self.orig_parameter_space.values())) target_space_num_dims = len(list(self.target_parameter_space.values())) @@ -378,5 +461,7 @@ def _try_generate_approx_inverse_mapping(self) -> None: try: self._pinv_matrix = pinv(proj_matrix) except LinAlgError as err: - raise RuntimeError(f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}") from err + raise RuntimeError( + f"Unable to generate reverse mapping using pseudo-inverse matrix: {repr(err)}" + ) from err assert self._pinv_matrix.shape == (target_space_num_dims, orig_space_num_dims) diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index 4aee0154b6..8b669f98e4 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -25,7 +25,9 @@ FlamlSpace: TypeAlias = Dict[str, flaml.tune.sample.Domain] -def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> Dict[str, FlamlDomain]: +def configspace_to_flaml_space( + config_space: ConfigSpace.ConfigurationSpace, +) -> Dict[str, FlamlDomain]: """ Converts a ConfigSpace.ConfigurationSpace to dict. @@ -49,13 +51,19 @@ def configspace_to_flaml_space(config_space: ConfigSpace.ConfigurationSpace) -> def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain: if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter): # FIXME: upper isn't included in the range - return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper) + return flaml_numeric_type[(type(parameter), parameter.log)]( + parameter.lower, parameter.upper + ) elif isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter): - return flaml_numeric_type[(type(parameter), parameter.log)](parameter.lower, parameter.upper + 1) + return flaml_numeric_type[(type(parameter), parameter.log)]( + parameter.lower, parameter.upper + 1 + ) elif isinstance(parameter, ConfigSpace.CategoricalHyperparameter): if len(np.unique(parameter.probabilities)) > 1: - raise ValueError("FLAML doesn't support categorical parameters with non-uniform probabilities.") - return flaml.tune.choice(parameter.choices) # TODO: set order? + raise ValueError( + "FLAML doesn't support categorical parameters with non-uniform probabilities." + ) + return flaml.tune.choice(parameter.choices) # TODO: set order? raise ValueError(f"Type of parameter {parameter} ({type(parameter)}) not supported.") return {param.name: _one_parameter_convert(param) for param in config_space.values()} diff --git a/mlos_core/mlos_core/tests/__init__.py b/mlos_core/mlos_core/tests/__init__.py index 6a0962f415..cff9016da7 100644 --- a/mlos_core/mlos_core/tests/__init__.py +++ b/mlos_core/mlos_core/tests/__init__.py @@ -19,7 +19,7 @@ from typing_extensions import TypeAlias -T = TypeVar('T') +T = TypeVar("T") def get_all_submodules(pkg: TypeAlias) -> List[str]: @@ -29,7 +29,9 @@ def get_all_submodules(pkg: TypeAlias) -> List[str]: Useful for dynamically enumerating subclasses. """ submodules = [] - for _, submodule_name, _ in walk_packages(pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None): + for _, submodule_name, _ in walk_packages( + pkg.__path__, prefix=f"{pkg.__name__}.", onerror=lambda x: None + ): submodules.append(submodule_name) return submodules @@ -41,7 +43,8 @@ def _get_all_subclasses(cls: Type[T]) -> Set[Type[T]]: Useful for dynamically enumerating expected test cases. 
""" return set(cls.__subclasses__()).union( - s for c in cls.__subclasses__() for s in _get_all_subclasses(c)) + s for c in cls.__subclasses__() for s in _get_all_subclasses(c) + ) def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> List[Type[T]]: @@ -57,5 +60,11 @@ def get_all_concrete_subclasses(cls: Type[T], pkg_name: Optional[str] = None) -> pkg = import_module(pkg_name) submodules = get_all_submodules(pkg) assert submodules - return sorted([subclass for subclass in _get_all_subclasses(cls) if not getattr(subclass, "__abstractmethods__", None)], - key=lambda c: (c.__module__, c.__name__)) + return sorted( + [ + subclass + for subclass in _get_all_subclasses(cls) + if not getattr(subclass, "__abstractmethods__", None) + ], + key=lambda c: (c.__module__, c.__name__), + ) diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index e0b094e4d6..68599e176b 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -15,24 +15,27 @@ @pytest.mark.filterwarnings("error:Not Implemented") -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_context_not_implemented_warning(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], - kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_context_not_implemented_warning( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """Make sure we raise warnings for the functionality that has not been implemented yet. """ if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) suggestion, _metadata = optimizer.suggest() - scores = pd.DataFrame({'score': [1]}) + scores = pd.DataFrame({"score": [1]}) context = pd.DataFrame([["something"]]) with pytest.raises(UserWarning): diff --git a/mlos_core/mlos_core/tests/optimizers/conftest.py b/mlos_core/mlos_core/tests/optimizers/conftest.py index 417b917552..fe82ff92bb 100644 --- a/mlos_core/mlos_core/tests/optimizers/conftest.py +++ b/mlos_core/mlos_core/tests/optimizers/conftest.py @@ -14,9 +14,9 @@ def configuration_space() -> CS.ConfigurationSpace: # Start defining a ConfigurationSpace for the Optimizer to search. space = CS.ConfigurationSpace(seed=1234) # Add a continuous input dimension between 0 and 1. - space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) + space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) # Add a categorical hyperparameter with 3 possible values. - space.add_hyperparameter(CS.CategoricalHyperparameter(name='y', choices=["a", "b", "c"])) + space.add_hyperparameter(CS.CategoricalHyperparameter(name="y", choices=["a", "b", "c"])) # Add a discrete input dimension between 0 and 10. 
- space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='z', lower=0, upper=10)) + space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="z", lower=0, upper=10)) return space diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index f9fe07fbf0..da5a3d492a 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -22,11 +22,13 @@ def data_frame() -> pd.DataFrame: The columns are deliberately *not* in alphabetic order. """ - return pd.DataFrame({ - 'y': ['a', 'b', 'c'], - 'x': [0.1, 0.2, 0.3], - 'z': [1, 5, 8], - }) + return pd.DataFrame( + { + "y": ["a", "b", "c"], + "x": [0.1, 0.2, 0.3], + "z": [1, 5, 8], + } + ) @pytest.fixture @@ -36,11 +38,13 @@ def one_hot_data_frame() -> npt.NDArray: The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array([ - [0.1, 1.0, 0.0, 0.0, 1.0], - [0.2, 0.0, 1.0, 0.0, 5.0], - [0.3, 0.0, 0.0, 1.0, 8.0], - ]) + return np.array( + [ + [0.1, 1.0, 0.0, 0.0, 1.0], + [0.2, 0.0, 1.0, 0.0, 5.0], + [0.3, 0.0, 0.0, 1.0, 8.0], + ] + ) @pytest.fixture @@ -50,11 +54,13 @@ def series() -> pd.Series: The columns are deliberately *not* in alphabetic order. """ - return pd.Series({ - 'y': 'b', - 'x': 0.4, - 'z': 3, - }) + return pd.Series( + { + "y": "b", + "x": 0.4, + "z": 3, + } + ) @pytest.fixture @@ -64,9 +70,11 @@ def one_hot_series() -> npt.NDArray: The columns follow the order of the hyperparameters in `configuration_space`. """ - return np.array([ - [0.4, 0.0, 1.0, 0.0, 3], - ]) + return np.array( + [ + [0.4, 0.0, 1.0, 0.0, 3], + ] + ) @pytest.fixture @@ -78,33 +86,34 @@ def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: """ return SmacOptimizer( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], ) -def test_to_1hot_data_frame(optimizer: BaseOptimizer, - data_frame: pd.DataFrame, - one_hot_data_frame: npt.NDArray) -> None: +def test_to_1hot_data_frame( + optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray +) -> None: """Toy problem to test one-hot encoding of dataframe.""" assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame) -def test_to_1hot_series(optimizer: BaseOptimizer, - series: pd.Series, one_hot_series: npt.NDArray) -> None: +def test_to_1hot_series( + optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray +) -> None: """Toy problem to test one-hot encoding of series.""" assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series) -def test_from_1hot_data_frame(optimizer: BaseOptimizer, - data_frame: pd.DataFrame, - one_hot_data_frame: npt.NDArray) -> None: +def test_from_1hot_data_frame( + optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray +) -> None: """Toy problem to test one-hot decoding of dataframe.""" assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict() -def test_from_1hot_series(optimizer: BaseOptimizer, - series: pd.Series, - one_hot_series: npt.NDArray) -> None: +def test_from_1hot_series( + optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray +) -> None: """Toy problem to test one-hot decoding of series.""" one_hot_df = optimizer._from_1hot(config=one_hot_series) assert one_hot_df.shape[0] == 1, f"Unexpected number of rows ({one_hot_df.shape[0]} != 1)" @@ -127,15 +136,15 @@ def 
test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> No assert (series_round_trip.z == series.z).all() -def test_round_trip_reverse_data_frame(optimizer: BaseOptimizer, - one_hot_data_frame: npt.NDArray) -> None: +def test_round_trip_reverse_data_frame( + optimizer: BaseOptimizer, one_hot_data_frame: npt.NDArray +) -> None: """Round-trip test for one-hot-decoding and then encoding of a numpy array.""" round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_data_frame)) assert round_trip == pytest.approx(one_hot_data_frame) -def test_round_trip_reverse_series(optimizer: BaseOptimizer, - one_hot_series: npt.NDArray) -> None: +def test_round_trip_reverse_series(optimizer: BaseOptimizer, one_hot_series: npt.NDArray) -> None: """Round-trip test for one-hot-decoding and then encoding of a numpy array.""" round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_series)) assert round_trip == pytest.approx(one_hot_series) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index 271bfce1d8..c1f743dd03 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -18,33 +18,44 @@ _LOG = logging.getLogger(__name__) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_multi_target_opt_wrong_weights(optimizer_class: Type[BaseOptimizer], kwargs: dict) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_multi_target_opt_wrong_weights( + optimizer_class: Type[BaseOptimizer], kwargs: dict +) -> None: """Make sure that the optimizer raises an error if the number of objective weights does not match the number of optimization targets. """ with pytest.raises(ValueError): optimizer_class( parameter_space=CS.ConfigurationSpace(seed=SEED), - optimization_targets=['main_score', 'other_score'], + optimization_targets=["main_score", "other_score"], objective_weights=[1], - **kwargs + **kwargs, ) -@pytest.mark.parametrize(('objective_weights'), [ - [2, 1], - [0.5, 0.5], - None, -]) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_multi_target_opt(objective_weights: Optional[List[float]], - optimizer_class: Type[BaseOptimizer], - kwargs: dict) -> None: +@pytest.mark.parametrize( + ("objective_weights"), + [ + [2, 1], + [0.5, 0.5], + None, + ], +) +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_multi_target_opt( + objective_weights: Optional[List[float]], optimizer_class: Type[BaseOptimizer], kwargs: dict +) -> None: """Toy multi-target optimization problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
""" @@ -52,21 +63,21 @@ def test_multi_target_opt(objective_weights: Optional[List[float]], def objective(point: pd.DataFrame) -> pd.DataFrame: # mix of hyperparameters, optimal is to select the highest possible - return pd.DataFrame({ - "main_score": point.x + point.y, - "other_score": point.x ** 2 + point.y ** 2, - }) + return pd.DataFrame( + { + "main_score": point.x + point.y, + "other_score": point.x**2 + point.y**2, + } + ) input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0)) optimizer = optimizer_class( parameter_space=input_space, - optimization_targets=['main_score', 'other_score'], + optimization_targets=["main_score", "other_score"], objective_weights=objective_weights, **kwargs, ) @@ -81,27 +92,28 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {'x', 'y'} + assert set(suggestion.columns) == {"x", "y"} # Check suggestion values are the expected dtype assert isinstance(suggestion.x.iloc[0], np.integer) assert isinstance(suggestion.y.iloc[0], np.floating) # Check that suggestion is in the space test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) + optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() + ) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
observation = objective(suggestion) assert isinstance(observation, pd.DataFrame) - assert set(observation.columns) == {'main_score', 'other_score'} + assert set(observation.columns) == {"main_score", "other_score"} optimizer.register(configs=suggestion, scores=observation) (best_config, best_score, best_context) = optimizer.get_best_observations() assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y'} - assert set(best_score.columns) == {'main_score', 'other_score'} + assert set(best_config.columns) == {"x", "y"} + assert set(best_score.columns) == {"main_score", "other_score"} assert best_config.shape == (1, 2) assert best_score.shape == (1, 2) @@ -109,7 +121,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y'} - assert set(all_scores.columns) == {'main_score', 'other_score'} + assert set(all_configs.columns) == {"x", "y"} + assert set(all_scores.columns) == {"main_score", "other_score"} assert all_configs.shape == (max_iterations, 2) assert all_scores.shape == (max_iterations, 2) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index b1c68ad136..7233918673 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -30,18 +30,22 @@ _LOG.setLevel(logging.DEBUG) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_create_optimizer_and_suggest( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """Test that we can create an optimizer and get a suggestion from it.""" if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) assert optimizer is not None @@ -58,30 +62,36 @@ def test_create_optimizer_and_suggest(configuration_space: CS.ConfigurationSpace optimizer.register_pending(configs=suggestion, metadata=metadata) -@pytest.mark.parametrize(('optimizer_class', 'kwargs'), [ - *[(member.value, {}) for member in OptimizerType], -]) -def test_basic_interface_toy_problem(configuration_space: CS.ConfigurationSpace, - optimizer_class: Type[BaseOptimizer], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_class", "kwargs"), + [ + *[(member.value, {}) for member in OptimizerType], + ], +) +def test_basic_interface_toy_problem( + configuration_space: CS.ConfigurationSpace, + optimizer_class: Type[BaseOptimizer], + kwargs: Optional[dict], +) -> None: """Toy problem to test the optimizers.""" # pylint: disable=too-many-locals max_iterations = 20 if kwargs is None: kwargs = {} if optimizer_class == OptimizerType.SMAC.value: - # SMAC sets the initial random samples as a percentage of the max iterations, which defaults to 100. 
- # To avoid having to train more than 25 model iterations, we set a lower number of max iterations. - kwargs['max_trials'] = max_iterations * 2 + # SMAC sets the initial random samples as a percentage of the max + # iterations, which defaults to 100. + # To avoid having to train more than 25 model iterations, we set a lower + # number of max iterations. + kwargs["max_trials"] = max_iterations * 2 def objective(x: pd.Series) -> pd.DataFrame: - return pd.DataFrame({"score": (6 * x - 2)**2 * np.sin(12 * x - 4)}) + return pd.DataFrame({"score": (6 * x - 2) ** 2 * np.sin(12 * x - 4)}) # Emukit doesn't allow specifying a random state, so we set the global seed. np.random.seed(SEED) optimizer = optimizer_class( - parameter_space=configuration_space, - optimization_targets=['score'], - **kwargs + parameter_space=configuration_space, optimization_targets=["score"], **kwargs ) with pytest.raises(ValueError, match="No observations"): @@ -94,12 +104,12 @@ def objective(x: pd.Series) -> pd.DataFrame: suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) assert metadata is None or isinstance(metadata, pd.DataFrame) - assert set(suggestion.columns) == {'x', 'y', 'z'} + assert set(suggestion.columns) == {"x", "y", "z"} # check that suggestion is in the space configuration = CS.Configuration(optimizer.parameter_space, suggestion.iloc[0].to_dict()) # Raises an error if outside of configuration space configuration.is_valid_configuration() - observation = objective(suggestion['x']) + observation = objective(suggestion["x"]) assert isinstance(observation, pd.DataFrame) optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -107,8 +117,8 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y', 'z'} - assert set(best_score.columns) == {'score'} + assert set(best_config.columns) == {"x", "y", "z"} + assert set(best_score.columns) == {"score"} assert best_config.shape == (1, 3) assert best_score.shape == (1, 1) assert best_score.score.iloc[0] < -5 @@ -117,12 +127,13 @@ def objective(x: pd.Series) -> pd.DataFrame: assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y', 'z'} - assert set(all_scores.columns) == {'score'} + assert set(all_configs.columns) == {"x", "y", "z"} + assert set(all_scores.columns) == {"score"} assert all_configs.shape == (20, 3) assert all_scores.shape == (20, 1) - # It would be better to put this into bayesian_optimizer_test but then we'd have to refit the model + # It would be better to put this into bayesian_optimizer_test but then we'd have + # to refit the model if isinstance(optimizer, BaseBayesianOptimizer): pred_best = optimizer.surrogate_predict(configs=best_config) assert pred_best.shape == (1,) @@ -131,38 +142,48 @@ def objective(x: pd.Series) -> pd.DataFrame: assert pred_all.shape == (20,) -@pytest.mark.parametrize(('optimizer_type'), [ - # Enumerate all supported Optimizers - # *[member for member in OptimizerType], - *list(OptimizerType), -]) +@pytest.mark.parametrize( + ("optimizer_type"), + [ + # Enumerate all supported Optimizers + # *[member for member in OptimizerType], + *list(OptimizerType), + ], +) def test_concrete_optimizer_type(optimizer_type: OptimizerType) -> None: """Test that all optimizer types are listed in the ConcreteOptimizer 
constraints.""" - assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] # pylint: disable=no-member - - -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument -]) -def test_create_optimizer_with_factory_method(configuration_space: CS.ConfigurationSpace, - optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: + # pylint: disable=no-member + assert optimizer_type.value in ConcreteOptimizer.__constraints__ # type: ignore[attr-defined] + + +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ], +) +def test_create_optimizer_with_factory_method( + configuration_space: CS.ConfigurationSpace, + optimizer_type: Optional[OptimizerType], + kwargs: Optional[dict], +) -> None: """Test that we can create an optimizer via a factory.""" if kwargs is None: kwargs = {} if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=configuration_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -178,16 +199,22 @@ def test_create_optimizer_with_factory_method(configuration_space: CS.Configurat assert myrepr.startswith(optimizer_type.value.__name__) -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument - (OptimizerType.SMAC, { - # Test with default config. - 'use_default_config': True, - # 'n_random_init': 10, - }), -]) +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ( + OptimizerType.SMAC, + { + # Test with default config. 
+ "use_default_config": True, + # 'n_random_init': 10, + }, + ), + ], +) def test_optimizer_with_llamatune(optimizer_type: OptimizerType, kwargs: Optional[dict]) -> None: """Toy problem to test the optimizers with llamatune space adapter.""" # pylint: disable=too-complex,disable=too-many-statements,disable=too-many-locals @@ -203,8 +230,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=1234) # Add two continuous inputs - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=3)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=3)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=3)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=3)) # Initialize an optimizer that uses LlamaTune space adapter space_adapter_kwargs = { @@ -227,7 +254,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: llamatune_optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=llamatune_optimizer_kwargs, space_adapter_type=SpaceAdapterType.LLAMATUNE, @@ -236,7 +263,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Initialize an optimizer that uses the original space optimizer: BaseOptimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=optimizer_kwargs, ) @@ -245,7 +272,7 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: assert optimizer.optimizer_parameter_space != llamatune_optimizer.optimizer_parameter_space llamatune_n_random_init = 0 - opt_n_random_init = int(kwargs.get('n_random_init', 0)) + opt_n_random_init = int(kwargs.get("n_random_init", 0)) if optimizer_type == OptimizerType.SMAC: assert isinstance(optimizer, SmacOptimizer) assert isinstance(llamatune_optimizer, SmacOptimizer) @@ -266,8 +293,10 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # loop for llamatune-optimizer suggestion, metadata = llamatune_optimizer.suggest() - _x, _y = suggestion['x'].iloc[0], suggestion['y'].iloc[0] - assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3., rel=1e-3) # optimizer explores 1-dimensional space + _x, _y = suggestion["x"].iloc[0], suggestion["y"].iloc[0] + assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx( + 3.0, rel=1e-3 + ) # optimizer explores 1-dimensional space observation = objective(suggestion) llamatune_optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -275,28 +304,33 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: best_observation = optimizer.get_best_observations() llamatune_best_observation = llamatune_optimizer.get_best_observations() - for (best_config, best_score, best_context) in (best_observation, llamatune_best_observation): + for best_config, best_score, best_context in (best_observation, llamatune_best_observation): assert isinstance(best_config, pd.DataFrame) assert isinstance(best_score, pd.DataFrame) assert best_context is None - assert set(best_config.columns) == {'x', 'y'} - assert set(best_score.columns) == {'score'} + assert set(best_config.columns) == {"x", "y"} + assert set(best_score.columns) == {"score"} (best_config, best_score, _context) = best_observation (llamatune_best_config, llamatune_best_score, _context) = 
llamatune_best_observation - # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's one, or close to that - assert best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] or \ - best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] + # LlamaTune's optimizer score should better (i.e., lower) than plain optimizer's + # one, or close to that + assert ( + best_score.score.iloc[0] > llamatune_best_score.score.iloc[0] + or best_score.score.iloc[0] + 1e-3 > llamatune_best_score.score.iloc[0] + ) # Retrieve and check all observations - for (all_configs, all_scores, all_contexts) in ( - optimizer.get_observations(), llamatune_optimizer.get_observations()): + for all_configs, all_scores, all_contexts in ( + optimizer.get_observations(), + llamatune_optimizer.get_observations(), + ): assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None - assert set(all_configs.columns) == {'x', 'y'} - assert set(all_scores.columns) == {'score'} + assert set(all_configs.columns) == {"x", "y"} + assert set(all_scores.columns) == {"score"} assert len(all_configs) == num_iters assert len(all_scores) == num_iters @@ -308,26 +342,32 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses(BaseOptimizer, # type: ignore[type-abstract] - pkg_name='mlos_core') +optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses( + BaseOptimizer, pkg_name="mlos_core" # type: ignore[type-abstract] +) assert optimizer_subclasses -@pytest.mark.parametrize(('optimizer_class'), optimizer_subclasses) +@pytest.mark.parametrize(("optimizer_class"), optimizer_subclasses) def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: """Test that all optimizer classes are listed in the OptimizerType enum.""" optimizer_type_classes = {member.value for member in OptimizerType} assert optimizer_class in optimizer_type_classes -@pytest.mark.parametrize(('optimizer_type', 'kwargs'), [ - # Default optimizer - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in OptimizerType], - # Optimizer with non-empty kwargs argument -]) -def test_mixed_numerics_type_input_space_types(optimizer_type: Optional[OptimizerType], kwargs: Optional[dict]) -> None: +@pytest.mark.parametrize( + ("optimizer_type", "kwargs"), + [ + # Default optimizer + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in OptimizerType], + # Optimizer with non-empty kwargs argument + ], +) +def test_mixed_numerics_type_input_space_types( + optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] +) -> None: """Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. 
""" @@ -341,19 +381,19 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: input_space = CS.ConfigurationSpace(seed=SEED) # add a mix of numeric datatypes - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name='x', lower=0, upper=5)) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0.0, upper=5.0)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="x", lower=0, upper=5)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0.0, upper=5.0)) if optimizer_type is None: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_kwargs=kwargs, ) else: optimizer = OptimizerFactory.create( parameter_space=input_space, - optimization_targets=['score'], + optimization_targets=["score"], optimizer_type=optimizer_type, optimizer_kwargs=kwargs, ) @@ -367,12 +407,14 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: for _ in range(max_iterations): suggestion, metadata = optimizer.suggest() assert isinstance(suggestion, pd.DataFrame) - assert (suggestion.columns == ['x', 'y']).all() + assert (suggestion.columns == ["x", "y"]).all() # Check suggestion values are the expected dtype - assert isinstance(suggestion['x'].iloc[0], np.integer) - assert isinstance(suggestion['y'].iloc[0], np.floating) + assert isinstance(suggestion["x"].iloc[0], np.integer) + assert isinstance(suggestion["y"].iloc[0], np.floating) # Check that suggestion is in the space - test_configuration = CS.Configuration(optimizer.parameter_space, suggestion.astype('O').iloc[0].to_dict()) + test_configuration = CS.Configuration( + optimizer.parameter_space, suggestion.astype("O").iloc[0].to_dict() + ) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. 
diff --git a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py index 07f23507d9..5d394cf4e2 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/identity_adapter_test.py @@ -16,22 +16,33 @@ def test_identity_adapter() -> None: """Tests identity adapter.""" input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='float_1', lower=0, upper=100)) + CS.UniformFloatHyperparameter(name="float_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) + CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) + ) adapter = IdentityAdapter(orig_parameter_space=input_space) num_configs = 10 - for sampled_config in input_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable # (false positive) - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + for sampled_config in input_space.sample_configuration( + size=num_configs + ): # pylint: disable=not-an-iterable # (false positive) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) assert target_config_df.equals(sampled_config_df) - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == sampled_config orig_config_df = adapter.transform(target_config_df) assert orig_config_df.equals(sampled_config_df) - orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() + ) assert orig_config == sampled_config diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index 5bddbaf807..f557b05883 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -26,42 +26,53 @@ def construct_parameter_space( for idx in range(n_continuous_params): input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name=f'cont_{idx}', lower=0, upper=64)) + CS.UniformFloatHyperparameter(name=f"cont_{idx}", lower=0, upper=64) + ) for idx in range(n_integer_params): input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name=f'int_{idx}', lower=-1, upper=256)) + CS.UniformIntegerHyperparameter(name=f"int_{idx}", lower=-1, upper=256) + ) for idx in range(n_categorical_params): input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name=f'str_{idx}', choices=[f'option_{idx}' for idx in range(5)])) + CS.CategoricalHyperparameter( + name=f"str_{idx}", choices=[f"option_{idx}" for idx in range(5)] + ) + ) return input_space -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - 
for param_space_kwargs in ( - {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) + ] + ), +) +def test_num_low_dims( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: # pylint: disable=too-many-locals """Tests LlamaTune's low-to-high space projection method.""" input_space = construct_parameter_space(**param_space_kwargs) # Number of target parameter space dimensions should be fewer than those of the original space with pytest.raises(ValueError): LlamaTuneAdapter( - orig_parameter_space=input_space, - num_low_dims=len(list(input_space.keys())) + orig_parameter_space=input_space, num_low_dims=len(list(input_space.keys())) ) # Enable only low-dimensional space projections @@ -69,13 +80,15 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N orig_parameter_space=input_space, num_low_dims=num_target_space_dims, special_param_values=None, - max_unique_values_per_param=None + max_unique_values_per_param=None, ) sampled_configs = adapter.target_parameter_space.sample_configuration(size=100) for sampled_config in sampled_configs: # pylint: disable=not-an-iterable # (false positive) # Transform low-dim config to high-dim point/config - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) orig_config_df = adapter.transform(sampled_config_df) # High-dim (i.e., original) config should be valid @@ -86,18 +99,28 @@ def test_num_low_dims(num_target_space_dims: int, param_space_kwargs: dict) -> N target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == sampled_config # Try inverse projection (i.e., high-to-low) for previously unseen configs unseen_sampled_configs = adapter.target_parameter_space.sample_configuration(size=25) - for unseen_sampled_config in unseen_sampled_configs: # pylint: 
disable=not-an-iterable # (false positive) - if unseen_sampled_config in sampled_configs: # pylint: disable=unsupported-membership-test # (false positive) + for ( + unseen_sampled_config + ) in unseen_sampled_configs: # pylint: disable=not-an-iterable # (false positive) + if ( + unseen_sampled_config in sampled_configs + ): # pylint: disable=unsupported-membership-test # (false positive) continue - unseen_sampled_config_df = pd.DataFrame([unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys())) + unseen_sampled_config_df = pd.DataFrame( + [unseen_sampled_config.values()], columns=list(unseen_sampled_config.keys()) + ) with pytest.raises(ValueError): - _ = adapter.inverse_transform(unseen_sampled_config_df) # pylint: disable=redefined-variable-type + _ = adapter.inverse_transform( + unseen_sampled_config_df + ) # pylint: disable=redefined-variable-type def test_special_parameter_values_validation() -> None: @@ -106,15 +129,14 @@ def test_special_parameter_values_validation() -> None: """ input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str', choices=[f'choice_{idx}' for idx in range(5)])) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont', lower=-1, upper=100)) - input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int', lower=0, upper=100)) + CS.CategoricalHyperparameter(name="str", choices=[f"choice_{idx}" for idx in range(5)]) + ) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont", lower=-1, upper=100)) + input_space.add_hyperparameter(CS.UniformIntegerHyperparameter(name="int", lower=0, upper=100)) # Only UniformIntegerHyperparameters are currently supported with pytest.raises(NotImplementedError): - special_param_values_dict_1 = {'str': 'choice_1'} + special_param_values_dict_1 = {"str": "choice_1"} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -123,7 +145,7 @@ def test_special_parameter_values_validation() -> None: ) with pytest.raises(NotImplementedError): - special_param_values_dict_2 = {'cont': -1} + special_param_values_dict_2 = {"cont": -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -132,8 +154,8 @@ def test_special_parameter_values_validation() -> None: ) # Special value should belong to parameter value domain - with pytest.raises(ValueError, match='value domain'): - special_param_values_dict = {'int': -1} + with pytest.raises(ValueError, match="value domain"): + special_param_values_dict = {"int": -1} LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=2, @@ -143,15 +165,15 @@ def test_special_parameter_values_validation() -> None: # Invalid dicts; ValueError should be thrown invalid_special_param_values_dicts: List[Dict[str, Any]] = [ - {'int-Q': 0}, # parameter does not exist - {'int': {0: 0.2}}, # invalid definition - {'int': 0.2}, # invalid parameter value - {'int': (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) - {'int': [0, 0]}, # duplicate special values - {'int': []}, # empty list - {'int': [{0: 0.2}]}, - {'int': [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct - {'int': [(0, 0.1), (0, 0.2)]}, # duplicate special values + {"int-Q": 0}, # parameter does not exist + {"int": {0: 0.2}}, # invalid definition + {"int": 0.2}, # invalid parameter value + {"int": (0.4, 0)}, # (biasing %, special value) instead of (special value, biasing %) + {"int": [0, 0]}, # duplicate special values + {"int": []}, # empty 
list + {"int": [{0: 0.2}]}, + {"int": [(0.4, 0), (1, 0.7)]}, # first tuple is inverted; second is correct + {"int": [(0, 0.1), (0, 0.2)]}, # duplicate special values ] for spv_dict in invalid_special_param_values_dicts: with pytest.raises(ValueError): @@ -164,13 +186,13 @@ def test_special_parameter_values_validation() -> None: # Biasing percentage of special value(s) are invalid invalid_special_param_values_dicts = [ - {'int': (0, 1.1)}, # >1 probability - {'int': (0, 0)}, # Zero probability - {'int': (0, -0.1)}, # Negative probability - {'int': (0, 20)}, # 2,000% instead of 20% - {'int': [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% - {'int': [(0, 0.4), (1, 0.7)]}, # combined probability >100% - {'int': [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. + {"int": (0, 1.1)}, # >1 probability + {"int": (0, 0)}, # Zero probability + {"int": (0, -0.1)}, # Negative probability + {"int": (0, 20)}, # 2,000% instead of 20% + {"int": [0, 1, 2, 3, 4, 5]}, # default biasing is 20%; 6 values * 20% > 100% + {"int": [(0, 0.4), (1, 0.7)]}, # combined probability >100% + {"int": [(0, -0.4), (1, 0.7)]}, # probability for value 0 is invalid. ] for spv_dict in invalid_special_param_values_dicts: @@ -186,19 +208,25 @@ def test_special_parameter_values_validation() -> None: def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[CS.Configuration]: for sampled_config in adapter.target_parameter_space.sample_configuration(size=num_configs): # Transform low-dim config to high-dim config - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) orig_config_df = adapter.transform(sampled_config_df) - orig_config = CS.Configuration(adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict()) + orig_config = CS.Configuration( + adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() + ) yield orig_config -def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex +def test_special_parameter_values_biasing() -> None: # pylint: disable=too-complex """Tests LlamaTune's special parameter values biasing methodology.""" input_space = CS.ConfigurationSpace(seed=1234) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_1", lower=0, upper=100) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=100)) + CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=100) + ) num_configs = 400 bias_percentage = LlamaTuneAdapter.DEFAULT_SPECIAL_PARAM_VALUE_BIASING_PERCENTAGE @@ -206,10 +234,10 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co # Single parameter; single special value special_param_value_dicts: List[Dict[str, Any]] = [ - {'int_1': 0}, - {'int_1': (0, bias_percentage)}, - {'int_1': [0]}, - {'int_1': [(0, bias_percentage)]} + {"int_1": 0}, + {"int_1": (0, bias_percentage)}, + {"int_1": [0]}, + {"int_1": [(0, bias_percentage)]}, ] for spv_dict in special_param_value_dicts: @@ -221,13 +249,14 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co ) special_value_occurrences = sum( - 1 for config in gen_random_configs(adapter, num_configs) if config['int_1'] == 0) + 1 for config in gen_random_configs(adapter, num_configs) if config["int_1"] == 0 + ) assert (1 - eps) * 
int(num_configs * bias_percentage) <= special_value_occurrences # Single parameter; multiple special values special_param_value_dicts = [ - {'int_1': [0, 1]}, - {'int_1': [(0, bias_percentage), (1, bias_percentage)]} + {"int_1": [0, 1]}, + {"int_1": [(0, bias_percentage), (1, bias_percentage)]}, ] for spv_dict in special_param_value_dicts: @@ -240,9 +269,9 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co special_values_occurrences = {0: 0, 1: 0} for config in gen_random_configs(adapter, num_configs): - if config['int_1'] == 0: + if config["int_1"] == 0: special_values_occurrences[0] += 1 - elif config['int_1'] == 1: + elif config["int_1"] == 1: special_values_occurrences[1] += 1 assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_occurrences[0] @@ -250,8 +279,8 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co # Multiple parameters; multiple special values; different biasing percentage spv_dict = { - 'int_1': [(0, bias_percentage), (1, bias_percentage / 2)], - 'int_2': [(2, bias_percentage / 2), (100, bias_percentage * 1.5)] + "int_1": [(0, bias_percentage), (1, bias_percentage / 2)], + "int_2": [(2, bias_percentage / 2), (100, bias_percentage * 1.5)], } adapter = LlamaTuneAdapter( orig_parameter_space=input_space, @@ -261,42 +290,52 @@ def test_special_parameter_values_biasing() -> None: # pylint: disable=too-co ) special_values_instances: Dict[str, Dict[int, int]] = { - 'int_1': {0: 0, 1: 0}, - 'int_2': {2: 0, 100: 0}, + "int_1": {0: 0, 1: 0}, + "int_2": {2: 0, 100: 0}, } for config in gen_random_configs(adapter, num_configs): - if config['int_1'] == 0: - special_values_instances['int_1'][0] += 1 - elif config['int_1'] == 1: - special_values_instances['int_1'][1] += 1 - - if config['int_2'] == 2: - special_values_instances['int_2'][2] += 1 - elif config['int_2'] == 100: - special_values_instances['int_2'][100] += 1 - - assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances['int_1'][0] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_1'][1] - assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances['int_2'][2] - assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances['int_2'][100] + if config["int_1"] == 0: + special_values_instances["int_1"][0] += 1 + elif config["int_1"] == 1: + special_values_instances["int_1"][1] += 1 + + if config["int_2"] == 2: + special_values_instances["int_2"][2] += 1 + elif config["int_2"] == 100: + special_values_instances["int_2"][100] += 1 + + assert (1 - eps) * int(num_configs * bias_percentage) <= special_values_instances["int_1"][0] + assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_1"][ + 1 + ] + assert (1 - eps) * int(num_configs * bias_percentage / 2) <= special_values_instances["int_2"][ + 2 + ] + assert (1 - eps) * int(num_configs * bias_percentage * 1.5) <= special_values_instances[ + "int_2" + ][100] def test_max_unique_values_per_param() -> None: """Tests LlamaTune's parameter values discretization implementation.""" # Define config space with a mix of different parameter types input_space = CS.ConfigurationSpace(seed=1234) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5)) input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont_1', lower=0, upper=5)) - input_space.add_hyperparameter( - CS.UniformFloatHyperparameter(name='cont_2', 
lower=1, upper=100)) + CS.UniformFloatHyperparameter(name="cont_2", lower=1, upper=100) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_1', lower=1, upper=10)) + CS.UniformIntegerHyperparameter(name="int_1", lower=1, upper=10) + ) input_space.add_hyperparameter( - CS.UniformIntegerHyperparameter(name='int_2', lower=0, upper=2048)) + CS.UniformIntegerHyperparameter(name="int_2", lower=0, upper=2048) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_1', choices=['on', 'off'])) + CS.CategoricalHyperparameter(name="str_1", choices=["on", "off"]) + ) input_space.add_hyperparameter( - CS.CategoricalHyperparameter(name='str_2', choices=[f'choice_{idx}' for idx in range(10)])) + CS.CategoricalHyperparameter(name="str_2", choices=[f"choice_{idx}" for idx in range(10)]) + ) # Restrict the number of unique parameter values num_configs = 200 @@ -319,23 +358,30 @@ def test_max_unique_values_per_param() -> None: assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: dict) -> None: # pylint: disable=too-many-locals +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) + ] + ), +) +def test_approx_inverse_mapping( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: # pylint: disable=too-many-locals """Tests LlamaTune's approximate high-to-low space projection method, using pseudo- inverse. 
""" @@ -350,9 +396,11 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: use_approximate_reverse_mapping=False, ) - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.raises(ValueError): - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) _ = adapter.inverse_transform(sampled_config_df) # Enable low-dimensional space projection *and* reverse mapping @@ -365,39 +413,61 @@ def test_approx_inverse_mapping(num_target_space_dims: int, param_space_kwargs: ) # Warning should be printed the first time - sampled_config = input_space.sample_configuration() # size=1) + sampled_config = input_space.sample_configuration() # size=1) with pytest.warns(UserWarning): - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) adapter.target_parameter_space.check_configuration(target_config) # Test inverse transform with 100 random configs for _ in range(100): - sampled_config = input_space.sample_configuration() # size=1) - sampled_config_df = pd.DataFrame([sampled_config.values()], columns=list(sampled_config.keys())) + sampled_config = input_space.sample_configuration() # size=1) + sampled_config_df = pd.DataFrame( + [sampled_config.values()], columns=list(sampled_config.keys()) + ) target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) adapter.target_parameter_space.check_configuration(target_config) -@pytest.mark.parametrize(('num_low_dims', 'special_param_values', 'max_unique_values_per_param'), ([ - (num_low_dims, special_param_values, max_unique_values_per_param) - for num_low_dims in (8, 16) - for special_param_values in ( - {'int_1': -1, 'int_2': -1, 'int_3': -1, 'int_4': [-1, 0]}, - {'int_1': (-1, 0.1), 'int_2': -1, 'int_3': (-1, 0.3), 'int_4': [(-1, 0.1), (0, 0.2)]}, - ) - for max_unique_values_per_param in (50, 250) -])) -def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int) -> None: +@pytest.mark.parametrize( + ("num_low_dims", "special_param_values", "max_unique_values_per_param"), + ( + [ + (num_low_dims, special_param_values, max_unique_values_per_param) + for num_low_dims in (8, 16) + for special_param_values in ( + {"int_1": -1, "int_2": -1, "int_3": -1, "int_4": [-1, 0]}, + { + "int_1": (-1, 0.1), + "int_2": -1, + "int_3": (-1, 0.3), + "int_4": [(-1, 0.1), (0, 0.2)], + }, + ) + for max_unique_values_per_param in (50, 250) + ] + ), +) +def test_llamatune_pipeline( + num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int +) -> None: """Tests LlamaTune space adapter when all components are 
active.""" # pylint: disable=too-many-locals # Define config space with a mix of different parameter types - input_space = construct_parameter_space(n_continuous_params=10, n_integer_params=10, n_categorical_params=5) + input_space = construct_parameter_space( + n_continuous_params=10, n_integer_params=10, n_categorical_params=5 + ) adapter = LlamaTuneAdapter( orig_parameter_space=input_space, num_low_dims=num_low_dims, @@ -406,13 +476,16 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u ) special_value_occurrences = { + # pylint: disable=protected-access param: {special_value: 0 for special_value, _ in tuples_list} - for param, tuples_list in adapter._special_param_values_dict.items() # pylint: disable=protected-access + for param, tuples_list in adapter._special_param_values_dict.items() } unique_values_dict: Dict[str, Set] = {param: set() for param in input_space.keys()} num_configs = 1000 - for config in adapter.target_parameter_space.sample_configuration(size=num_configs): # pylint: disable=not-an-iterable + for config in adapter.target_parameter_space.sample_configuration( + size=num_configs + ): # pylint: disable=not-an-iterable # Transform low-dim config to high-dim point/config sampled_config_df = pd.DataFrame([config.values()], columns=list(config.keys())) orig_config_df = adapter.transform(sampled_config_df) @@ -423,7 +496,9 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u # Transform high-dim config back to low-dim target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same - target_config = CS.Configuration(adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict()) + target_config = CS.Configuration( + adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + ) assert target_config == config for param, value in orig_config.items(): @@ -437,35 +512,48 @@ def test_llamatune_pipeline(num_low_dims: int, special_param_values: dict, max_u # Ensure that occurrences of special values do not significantly deviate from expected eps = 0.2 - for param, tuples_list in adapter._special_param_values_dict.items(): # pylint: disable=protected-access + for ( + param, + tuples_list, + ) in adapter._special_param_values_dict.items(): # pylint: disable=protected-access for value, bias_percentage in tuples_list: - assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[param][value] + assert (1 - eps) * int(num_configs * bias_percentage) <= special_value_occurrences[ + param + ][value] # Ensure that number of unique values is less than the maximum number allowed for _, unique_values in unique_values_dict.items(): assert len(unique_values) <= max_unique_values_per_param -@pytest.mark.parametrize(('num_target_space_dims', 'param_space_kwargs'), ([ - (num_target_space_dims, param_space_kwargs) - for num_target_space_dims in (2, 4) - for num_orig_space_factor in (1.5, 4) - for param_space_kwargs in ( - {'n_continuous_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_integer_params': int(num_target_space_dims * num_orig_space_factor)}, - {'n_categorical_params': int(num_target_space_dims * num_orig_space_factor)}, - # Mix of all three types - { - 'n_continuous_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_integer_params': int(num_target_space_dims * num_orig_space_factor / 3), - 'n_categorical_params': int(num_target_space_dims * num_orig_space_factor / 3), - }, - ) -])) -def 
test_deterministic_behavior_for_same_seed(num_target_space_dims: int, param_space_kwargs: dict) -> None: +@pytest.mark.parametrize( + ("num_target_space_dims", "param_space_kwargs"), + ( + [ + (num_target_space_dims, param_space_kwargs) + for num_target_space_dims in (2, 4) + for num_orig_space_factor in (1.5, 4) + for param_space_kwargs in ( + {"n_continuous_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_integer_params": int(num_target_space_dims * num_orig_space_factor)}, + {"n_categorical_params": int(num_target_space_dims * num_orig_space_factor)}, + # Mix of all three types + { + "n_continuous_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_integer_params": int(num_target_space_dims * num_orig_space_factor / 3), + "n_categorical_params": int(num_target_space_dims * num_orig_space_factor / 3), + }, + ) + ] + ), +) +def test_deterministic_behavior_for_same_seed( + num_target_space_dims: int, param_space_kwargs: dict +) -> None: """Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. """ + def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: input_space = construct_parameter_space(**param_space_kwargs, seed=seed) @@ -478,7 +566,9 @@ def generate_target_param_space_configs(seed: int) -> List[CS.Configuration]: use_approximate_reverse_mapping=False, ) - sample_configs: List[CS.Configuration] = adapter.target_parameter_space.sample_configuration(size=100) + sample_configs: List[CS.Configuration] = ( + adapter.target_parameter_space.sample_configuration(size=100) + ) return sample_configs assert generate_target_param_space_configs(42) == generate_target_param_space_configs(42) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index fd22d0c257..188a0300e7 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -21,37 +21,48 @@ from mlos_core.tests import get_all_concrete_subclasses -@pytest.mark.parametrize(('space_adapter_type'), [ - # Enumerate all supported SpaceAdapters - # *[member for member in SpaceAdapterType], - *list(SpaceAdapterType), -]) +@pytest.mark.parametrize( + ("space_adapter_type"), + [ + # Enumerate all supported SpaceAdapters + # *[member for member in SpaceAdapterType], + *list(SpaceAdapterType), + ], +) def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: """Test that all optimizer types are listed in the ConcreteOptimizer constraints.""" # pylint: disable=no-member - assert space_adapter_type.value in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] - - -@pytest.mark.parametrize(('space_adapter_type', 'kwargs'), [ - # Default space adapter - (None, {}), - # Enumerate all supported Optimizers - *[(member, {}) for member in SpaceAdapterType], -]) -def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict]) -> None: + assert ( + space_adapter_type.value + in ConcreteSpaceAdapter.__constraints__ # type: ignore[attr-defined] + ) + + +@pytest.mark.parametrize( + ("space_adapter_type", "kwargs"), + [ + # Default space adapter + (None, {}), + # Enumerate all supported Optimizers + *[(member, {}) for member in SpaceAdapterType], + ], +) +def test_create_space_adapter_with_factory_method( + space_adapter_type: Optional[SpaceAdapterType], 
kwargs: Optional[dict] +) -> None: # Start defining a ConfigurationSpace for the Optimizer to search. input_space = CS.ConfigurationSpace(seed=1234) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='x', lower=0, upper=1)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="x", lower=0, upper=1)) # Add a single continuous input dimension between 0 and 1. - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name='y', lower=0, upper=1)) + input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="y", lower=0, upper=1)) # Adjust some kwargs for specific space adapters if space_adapter_type is SpaceAdapterType.LLAMATUNE: if kwargs is None: kwargs = {} - kwargs.setdefault('num_low_dims', 1) + kwargs.setdefault("num_low_dims", 1) space_adapter: BaseSpaceAdapter if space_adapter_type is None: @@ -69,19 +80,24 @@ def test_create_space_adapter_with_factory_method(space_adapter_type: Optional[S assert space_adapter is not None assert space_adapter.orig_parameter_space is not None myrepr = repr(space_adapter) - assert myrepr.startswith(space_adapter_type.value.__name__), \ - f"Expected {space_adapter_type.value.__name__} but got {myrepr}" + assert myrepr.startswith( + space_adapter_type.value.__name__ + ), f"Expected {space_adapter_type.value.__name__} but got {myrepr}" # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. -space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = \ - get_all_concrete_subclasses(BaseSpaceAdapter, pkg_name='mlos_core') # type: ignore[type-abstract] +space_adapter_subclasses: List[Type[BaseSpaceAdapter]] = get_all_concrete_subclasses( + BaseSpaceAdapter, # type: ignore[type-abstract] + pkg_name="mlos_core", +) assert space_adapter_subclasses -@pytest.mark.parametrize(('space_adapter_class'), space_adapter_subclasses) +@pytest.mark.parametrize(("space_adapter_class"), space_adapter_subclasses) def test_space_adapter_type_defs(space_adapter_class: Type[BaseSpaceAdapter]) -> None: """Test that all space adapter classes are listed in the SpaceAdapterType enum.""" - space_adapter_type_classes = {space_adapter_type.value for space_adapter_type in SpaceAdapterType} + space_adapter_type_classes = { + space_adapter_type.value for space_adapter_type in SpaceAdapterType + } assert space_adapter_class in space_adapter_type_classes diff --git a/mlos_core/mlos_core/tests/spaces/spaces_test.py b/mlos_core/mlos_core/tests/spaces/spaces_test.py index 35a8f9ebb3..9d4c17f160 100644 --- a/mlos_core/mlos_core/tests/spaces/spaces_test.py +++ b/mlos_core/mlos_core/tests/spaces/spaces_test.py @@ -39,9 +39,9 @@ def assert_is_uniform(arr: npt.NDArray) -> None: assert np.isclose(frequencies.sum(), 1) _f_chi_sq, f_p_value = scipy.stats.chisquare(frequencies) - assert np.isclose(kurtosis, -1.2, atol=.1) - assert p_value > .3 - assert f_p_value > .5 + assert np.isclose(kurtosis, -1.2, atol=0.1) + assert p_value > 0.3 + assert f_p_value > 0.5 def assert_is_log_uniform(arr: npt.NDArray, base: float = np.e) -> None: @@ -66,11 +66,12 @@ def test_is_log_uniform() -> None: def invalid_conversion_function(*args: Any) -> NoReturn: """A quick dummy function for the base class to make pylint happy.""" - raise NotImplementedError('subclass must override conversion_function') + raise NotImplementedError("subclass must override conversion_function") class BaseConversion(metaclass=ABCMeta): """Base class for testing optimizer space conversions.""" + 
conversion_function: Callable[..., OptimizerSpace] = invalid_conversion_function @abstractmethod @@ -142,8 +143,8 @@ def test_uniform_samples(self) -> None: assert_is_uniform(uniform) # Check that we get both ends of the sampled range returned to us. - assert input_space['c'].lower in integer_uniform - assert input_space['c'].upper in integer_uniform + assert input_space["c"].lower in integer_uniform + assert input_space["c"].upper in integer_uniform # integer uniform assert_is_uniform(integer_uniform) @@ -157,13 +158,13 @@ def test_uniform_categorical(self) -> None: assert 35 < counts[1] < 65 def test_weighted_categorical(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") def test_log_int_spaces(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") def test_log_float_spaces(self) -> None: - raise NotImplementedError('subclass must override') + raise NotImplementedError("subclass must override") class TestFlamlConversion(BaseConversion): @@ -171,13 +172,19 @@ class TestFlamlConversion(BaseConversion): conversion_function = staticmethod(configspace_to_flaml_space) - def sample(self, config_space: FlamlSpace, n_samples: int = 1) -> npt.NDArray: # type: ignore[override] + def sample( + self, + config_space: FlamlSpace, # type: ignore[override] + n_samples: int = 1, + ) -> npt.NDArray: assert isinstance(config_space, dict) assert isinstance(next(iter(config_space.values())), flaml.tune.sample.Domain) - ret: npt.NDArray = np.array([domain.sample(size=n_samples) for domain in config_space.values()]).T + ret: npt.NDArray = np.array( + [domain.sample(size=n_samples) for domain in config_space.values()] + ).T return ret - def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] + def get_parameter_names(self, config_space: FlamlSpace) -> List[str]: # type: ignore[override] assert isinstance(config_space, dict) ret: List[str] = list(config_space.keys()) return ret @@ -198,7 +205,9 @@ def test_dimensionality(self) -> None: def test_weighted_categorical(self) -> None: np.random.seed(42) input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1])) + input_space.add_hyperparameter( + CS.CategoricalHyperparameter("c", choices=["foo", "bar"], weights=[0.9, 0.1]) + ) with pytest.raises(ValueError, match="non-uniform"): configspace_to_flaml_space(input_space) @@ -207,7 +216,9 @@ def test_log_int_spaces(self) -> None: np.random.seed(42) # integer is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True)) + input_space.add_hyperparameter( + CS.UniformIntegerHyperparameter("d", lower=1, upper=20, log=True) + ) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -225,7 +236,9 @@ def test_log_float_spaces(self) -> None: # continuous is supported input_space = CS.ConfigurationSpace() - input_space.add_hyperparameter(CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter("b", lower=1, upper=5, log=True) + ) converted_space = configspace_to_flaml_space(input_space) # test log integer sampling @@ -235,6 +248,6 @@ def test_log_float_spaces(self) -> None: assert_is_log_uniform(float_log_uniform) -if __name__ == '__main__': +if __name__ == "__main__": # 
For attaching debugger debugging: pytest.main(["-vv", "-k", "test_log_int_spaces", __file__]) diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index f6933cbb6a..50c6880f87 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -27,7 +27,9 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame: return pd.DataFrame([dict(config)]) -def normalize_config(config_space: ConfigurationSpace, config: Union[Configuration, dict]) -> Configuration: +def normalize_config( + config_space: ConfigurationSpace, config: Union[Configuration, dict] +) -> Configuration: """ Convert a dictionary to a valid ConfigSpace configuration. @@ -48,8 +50,6 @@ def normalize_config(config_space: ConfigurationSpace, config: Union[Configurati """ cs_config = Configuration(config_space, values=config, allow_inactive_with_values=True) return Configuration( - config_space, values={ - key: cs_config[key] - for key in config_space.get_active_hyperparameters(cs_config) - } + config_space, + values={key: cs_config[key] for key in config_space.get_active_hyperparameters(cs_config)}, ) diff --git a/mlos_core/mlos_core/version.py b/mlos_core/mlos_core/version.py index 61eb665064..f8bc82063c 100644 --- a/mlos_core/mlos_core/version.py +++ b/mlos_core/mlos_core/version.py @@ -5,7 +5,7 @@ """Version number for the mlos_core package.""" # NOTE: This should be managed by bumpversion. -VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_core/setup.py b/mlos_core/setup.py index 3771d73f43..853615274f 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -19,15 +19,16 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -47,52 +48,56 @@ # we return nothing when the file is not available. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. 
- lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass - 'flaml': ['flaml[blendsearch]'], - 'smac': ['smac>=2.0.0'], # NOTE: Major refactoring on SMAC starting from v2.0.0 + "flaml": ["flaml[blendsearch]"], + "smac": ["smac>=2.0.0"], # NOTE: Major refactoring on SMAC starting from v2.0.0 } # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", ] setup( version=VERSION, install_requires=[ - 'scikit-learn>=1.2', - 'joblib>=1.1.1', # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which isn't currently released - 'scipy>=1.3.2', - 'numpy>=1.24', 'numpy<2.0.0', # FIXME: https://github.com/numpy/numpy/issues/26710 - 'pandas >= 2.2.0;python_version>="3.9"', 'Bottleneck > 1.3.5;python_version>="3.9"', + "scikit-learn>=1.2", + # CVE-2022-21797: scikit-learn dependency, addressed in 1.2.0dev0, which + # isn't currently released + "joblib>=1.1.1", + "scipy>=1.3.2", + "numpy>=1.24", + "numpy<2.0.0", # FIXME: https://github.com/numpy/numpy/issues/26710 + 'pandas >= 2.2.0;python_version>="3.9"', + 'Bottleneck > 1.3.5;python_version>="3.9"', 'pandas >= 1.0.3;python_version<"3.9"', - 'ConfigSpace>=0.7.1', + "ConfigSpace>=0.7.1", ], extras_require=extra_requires, **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_core"), diff --git a/mlos_viz/mlos_viz/__init__.py b/mlos_viz/mlos_viz/__init__.py index a2a36b54a9..ddf7727ec8 100644 --- a/mlos_viz/mlos_viz/__init__.py +++ b/mlos_viz/mlos_viz/__init__.py @@ -20,7 +20,7 @@ class MlosVizMethod(Enum): """What method to use for visualizing the experiment results.""" DABL = "dabl" - AUTO = DABL # use dabl as the current default + AUTO = DABL # use dabl as the current default def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) -> None: @@ -36,17 +36,21 @@ def ignore_plotter_warnings(plotter_method: MlosVizMethod = MlosVizMethod.AUTO) base.ignore_plotter_warnings() if plotter_method == MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel + mlos_viz.dabl.ignore_plotter_warnings() else: raise NotImplementedError(f"Unhandled method: {plotter_method}") -def plot(exp_data: Optional[ExperimentData] = None, *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - plotter_method: MlosVizMethod = MlosVizMethod.AUTO, - filter_warnings: bool = True, - **kwargs: Any) -> None: +def plot( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + plotter_method: MlosVizMethod = MlosVizMethod.AUTO, + filter_warnings: bool = True, + **kwargs: Any, +) -> None: """ Plots 
the results of the experiment. @@ -78,6 +82,7 @@ def plot(exp_data: Optional[ExperimentData] = None, *, if MlosVizMethod.DABL: import mlos_viz.dabl # pylint: disable=import-outside-toplevel + mlos_viz.dabl.plot(exp_data, results_df=results_df, objectives=objectives) else: raise NotImplementedError(f"Unhandled method: {plotter_method}") diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 10b9946051..84e1fb3bd3 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -18,7 +18,7 @@ from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_viz.util import expand_results_data_args -_SEABORN_VERS = version('seaborn') +_SEABORN_VERS = version("seaborn") def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: @@ -28,7 +28,7 @@ def _get_kwarg_defaults(target: Callable, **kwargs: Any) -> Dict[str, Any]: Note: this only works with non-positional kwargs (e.g., those after a * arg). """ target_kwargs = {} - for kword in target.__kwdefaults__: # or {} # intentionally omitted for now + for kword in target.__kwdefaults__: # or {} # intentionally omitted for now if kword in kwargs: target_kwargs[kword] = kwargs[kword] return target_kwargs @@ -39,14 +39,19 @@ def ignore_plotter_warnings() -> None: adding them to the warnings filter. """ warnings.filterwarnings("ignore", category=FutureWarning) - if _SEABORN_VERS <= '0.13.1': - warnings.filterwarnings("ignore", category=DeprecationWarning, module="seaborn", # but actually comes from pandas - message="is_categorical_dtype is deprecated and will be removed in a future version.") + if _SEABORN_VERS <= "0.13.1": + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + module="seaborn", # but actually comes from pandas + message="is_categorical_dtype is deprecated and will be removed in a future version.", + ) -def _add_groupby_desc_column(results_df: pandas.DataFrame, - groupby_columns: Optional[List[str]] = None, - ) -> Tuple[pandas.DataFrame, List[str], str]: +def _add_groupby_desc_column( + results_df: pandas.DataFrame, + groupby_columns: Optional[List[str]] = None, +) -> Tuple[pandas.DataFrame, List[str], str]: """ Adds a group descriptor column to the results_df. @@ -64,17 +69,19 @@ def _add_groupby_desc_column(results_df: pandas.DataFrame, if groupby_columns is None: groupby_columns = ["tunable_config_trial_group_id", "tunable_config_id"] groupby_column = ",".join(groupby_columns) - results_df[groupby_column] = results_df[groupby_columns].astype(str).apply( - lambda x: ",".join(x), axis=1) # pylint: disable=unnecessary-lambda + results_df[groupby_column] = ( + results_df[groupby_columns].astype(str).apply(lambda x: ",".join(x), axis=1) + ) # pylint: disable=unnecessary-lambda groupby_columns.append(groupby_column) return (results_df, groupby_columns, groupby_column) -def augment_results_df_with_config_trial_group_stats(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - requested_result_cols: Optional[Iterable[str]] = None, - ) -> pandas.DataFrame: +def augment_results_df_with_config_trial_group_stats( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + requested_result_cols: Optional[Iterable[str]] = None, +) -> pandas.DataFrame: # pylint: disable=too-complex """ Add a number of useful statistical measure columns to the results dataframe. 
@@ -131,30 +138,47 @@ def augment_results_df_with_config_trial_group_stats(exp_data: Optional[Experime raise ValueError(f"Not enough data: {len(results_groups)}") if requested_result_cols is None: - result_cols = set(col for col in results_df.columns if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX)) + result_cols = set( + col + for col in results_df.columns + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) + ) else: - result_cols = set(col for col in requested_result_cols - if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns) - result_cols.update(set(ExperimentData.RESULT_COLUMN_PREFIX + col for col in requested_result_cols - if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns)) + result_cols = set( + col + for col in requested_result_cols + if col.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and col in results_df.columns + ) + result_cols.update( + set( + ExperimentData.RESULT_COLUMN_PREFIX + col + for col in requested_result_cols + if ExperimentData.RESULT_COLUMN_PREFIX in results_df.columns + ) + ) def compute_zscore_for_group_agg( - results_groups_perf: "SeriesGroupBy", - stats_df: pandas.DataFrame, - result_col: str, - agg: Union[Literal["mean"], Literal["var"], Literal["std"]] + results_groups_perf: "SeriesGroupBy", + stats_df: pandas.DataFrame, + result_col: str, + agg: Union[Literal["mean"], Literal["var"], Literal["std"]], ) -> None: - results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? - # Compute the zscore of the chosen aggregate performance of each group into each row in the dataframe. + results_groups_perf_aggs = results_groups_perf.agg(agg) # TODO: avoid recalculating? + # Compute the zscore of the chosen aggregate performance of each group into + # each row in the dataframe. stats_df[result_col + f".{agg}_mean"] = results_groups_perf_aggs.mean() stats_df[result_col + f".{agg}_stddev"] = results_groups_perf_aggs.std() - stats_df[result_col + f".{agg}_zscore"] = \ - (stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"]) \ - / stats_df[result_col + f".{agg}_stddev"] - stats_df.drop(columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True) + stats_df[result_col + f".{agg}_zscore"] = ( + stats_df[result_col + f".{agg}"] - stats_df[result_col + f".{agg}_mean"] + ) / stats_df[result_col + f".{agg}_stddev"] + stats_df.drop( + columns=[result_col + ".var_" + agg for agg in ("mean", "stddev")], inplace=True + ) augmented_results_df = results_df - augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform("count") + augmented_results_df["tunable_config_trial_group_size"] = results_groups["trial_id"].transform( + "count" + ) for result_col in result_cols: if not result_col.startswith(ExperimentData.RESULT_COLUMN_PREFIX): continue @@ -173,20 +197,21 @@ def compute_zscore_for_group_agg( compute_zscore_for_group_agg(results_groups_perf, stats_df, result_col, "var") quantiles = [0.50, 0.75, 0.90, 0.95, 0.99] - for quantile in quantiles: # TODO: can we do this in one pass? + for quantile in quantiles: # TODO: can we do this in one pass? 
quantile_col = f"{result_col}.p{int(quantile * 100)}" stats_df[quantile_col] = results_groups_perf.transform("quantile", quantile) augmented_results_df = pandas.concat([augmented_results_df, stats_df], axis=1) return augmented_results_df -def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - top_n_configs: int = 10, - method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", - ) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: +def limit_top_n_configs( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + top_n_configs: int = 10, + method: Literal["mean", "p50", "p75", "p90", "p95", "p99"] = "mean", +) -> Tuple[pandas.DataFrame, List[int], Dict[str, bool]]: # pylint: disable=too-many-locals """ Utility function to process the results and determine the best performing configs @@ -199,24 +224,30 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, results_df : Optional[pandas.DataFrame] The results dataframe to augment, by default None to use the results_df property. objectives : Iterable[str], optional - Which result column(s) to use for sorting the configs, and in which direction ("min" or "max"). + Which result column(s) to use for sorting the configs, and in which + direction ("min" or "max"). By default None to automatically select the experiment objectives. top_n_configs : int, optional How many configs to return, including the default, by default 20. method: Literal["mean", "median", "p50", "p75", "p90", "p95", "p99"] = "mean", - Which statistical method to use when sorting the config groups before determining the cutoff, by default "mean". + Which statistical method to use when sorting the config groups before + determining the cutoff, by default "mean". Returns ------- - (top_n_config_results_df, top_n_config_ids, orderby_cols) : Tuple[pandas.DataFrame, List[int], Dict[str, bool]] - The filtered results dataframe, the config ids, and the columns used to order the configs. + (top_n_config_results_df, top_n_config_ids, orderby_cols) : + Tuple[pandas.DataFrame, List[int], Dict[str, bool]] + The filtered results dataframe, the config ids, and the columns used to + order the configs. """ # Do some input checking first. if method not in ["mean", "median", "p50", "p75", "p90", "p95", "p99"]: raise ValueError(f"Invalid method: {method}") # Prepare the orderby columns. - (results_df, objs_cols) = expand_results_data_args(exp_data, results_df=results_df, objectives=objectives) + (results_df, objs_cols) = expand_results_data_args( + exp_data, results_df=results_df, objectives=objectives + ) assert isinstance(results_df, pandas.DataFrame) # Augment the results dataframe with some useful stats. @@ -229,13 +260,17 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, # results_df is not None and is in fact a DataFrame, so we periodically assert # it in this func for now. 
assert results_df is not None - orderby_cols: Dict[str, bool] = {obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items()} + orderby_cols: Dict[str, bool] = { + obj_col + f".{method}": ascending for (obj_col, ascending) in objs_cols.items() + } config_id_col = "tunable_config_id" - group_id_col = "tunable_config_trial_group_id" # first trial_id per config group + group_id_col = "tunable_config_trial_group_id" # first trial_id per config group trial_id_col = "trial_id" - default_config_id = results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id + default_config_id = ( + results_df[trial_id_col].min() if exp_data is None else exp_data.default_tunable_config_id + ) assert default_config_id is not None, "Failed to determine default config id." # Filter out configs whose variance is too large. @@ -247,16 +282,18 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, singletons_mask = results_df["tunable_config_trial_group_size"] == 1 else: singletons_mask = results_df["tunable_config_trial_group_size"] > 1 - results_df = results_df.loc[( - (results_df[f"{obj_col}.var_zscore"].abs() < 2) - | (singletons_mask) - | (results_df[config_id_col] == default_config_id) - )] + results_df = results_df.loc[ + ( + (results_df[f"{obj_col}.var_zscore"].abs() < 2) + | (singletons_mask) + | (results_df[config_id_col] == default_config_id) + ) + ] assert results_df is not None # Also, filter results that are worse than the default. default_config_results_df = results_df.loc[results_df[config_id_col] == default_config_id] - for (orderby_col, ascending) in orderby_cols.items(): + for orderby_col, ascending in orderby_cols.items(): default_vals = default_config_results_df[orderby_col].unique() assert len(default_vals) == 1 default_val = default_vals[0] @@ -268,29 +305,38 @@ def limit_top_n_configs(exp_data: Optional[ExperimentData] = None, # Now regroup and filter to the top-N configs by their group performance dimensions. assert results_df is not None - group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[orderby_cols.keys()] - top_n_config_ids: List[int] = group_results_df.sort_values( - by=list(orderby_cols.keys()), ascending=list(orderby_cols.values())).head(top_n_configs).index.tolist() + group_results_df: pandas.DataFrame = results_df.groupby(config_id_col).first()[ + orderby_cols.keys() + ] + top_n_config_ids: List[int] = ( + group_results_df.sort_values( + by=list(orderby_cols.keys()), ascending=list(orderby_cols.values()) + ) + .head(top_n_configs) + .index.tolist() + ) # Remove the default config if it's included. We'll add it back later. if default_config_id in top_n_config_ids: top_n_config_ids.remove(default_config_id) # Get just the top-n config results. # Sort by the group ids. - top_n_config_results_df = results_df.loc[( - results_df[config_id_col].isin(top_n_config_ids) - )].sort_values([group_id_col, config_id_col, trial_id_col]) + top_n_config_results_df = results_df.loc[ + (results_df[config_id_col].isin(top_n_config_ids)) + ].sort_values([group_id_col, config_id_col, trial_id_col]) # Place the default config at the top of the list. 
top_n_config_ids.insert(0, default_config_id) - top_n_config_results_df = pandas.concat([default_config_results_df, top_n_config_results_df], axis=0) + top_n_config_results_df = pandas.concat( + [default_config_results_df, top_n_config_results_df], axis=0 + ) return (top_n_config_results_df, top_n_config_ids, orderby_cols) def plot_optimizer_trends( - exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, ) -> None: """ Plots the optimizer trends for the Experiment. @@ -309,12 +355,16 @@ def plot_optimizer_trends( (results_df, obj_cols) = expand_results_data_args(exp_data, results_df, objectives) (results_df, groupby_columns, groupby_column) = _add_groupby_desc_column(results_df) - for (objective_column, ascending) in obj_cols.items(): + for objective_column, ascending in obj_cols.items(): incumbent_column = objective_column + ".incumbent" # Determine the mean of each config trial group to match the box plots. - group_results_df = results_df.groupby(groupby_columns)[objective_column].mean()\ - .reset_index().sort_values(groupby_columns) + group_results_df = ( + results_df.groupby(groupby_columns)[objective_column] + .mean() + .reset_index() + .sort_values(groupby_columns) + ) # # Note: technically the optimizer (usually) uses the *first* result for a # given config trial group before moving on to a new config (x-axis), so @@ -352,24 +402,29 @@ def plot_optimizer_trends( ax=axis, ) - plt.yscale('log') + plt.yscale("log") plt.ylabel(objective_column.replace(ExperimentData.RESULT_COLUMN_PREFIX, "")) plt.xlabel("Config Trial Group ID, Config ID") plt.xticks(rotation=90, fontsize=8) - plt.title("Optimizer Trends for Experiment: " + exp_data.experiment_id if exp_data is not None else "") + plt.title( + "Optimizer Trends for Experiment: " + exp_data.experiment_id + if exp_data is not None + else "" + ) plt.grid() plt.show() # type: ignore[no-untyped-call] -def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, - *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - with_scatter_plot: bool = False, - **kwargs: Any, - ) -> None: +def plot_top_n_configs( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, + with_scatter_plot: bool = False, + **kwargs: Any, +) -> None: # pylint: disable=too-many-locals """ Plots the top-N configs along with the default config for the given ExperimentData. 
@@ -397,12 +452,16 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, top_n_config_args["results_df"] = results_df if "objectives" not in top_n_config_args: top_n_config_args["objectives"] = objectives - (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs(exp_data=exp_data, **top_n_config_args) + (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs( + exp_data=exp_data, **top_n_config_args + ) - (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column(top_n_config_results_df) + (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column( + top_n_config_results_df + ) top_n = len(top_n_config_results_df[groupby_column].unique()) - 1 - for (orderby_col, ascending) in orderby_cols.items(): + for orderby_col, ascending in orderby_cols.items(): opt_tgt = orderby_col.replace(ExperimentData.RESULT_COLUMN_PREFIX, "") (_fig, axis) = plt.subplots() sns.violinplot( @@ -422,12 +481,12 @@ def plot_top_n_configs(exp_data: Optional[ExperimentData] = None, plt.grid() (xticks, xlabels) = plt.xticks() # default should be in the first position based on top_n_configs() return - xlabels[0] = "default" # type: ignore[call-overload] - plt.xticks(xticks, xlabels) # type: ignore[arg-type] + xlabels[0] = "default" # type: ignore[call-overload] + plt.xticks(xticks, xlabels) # type: ignore[arg-type] plt.xlabel("Config Trial Group, Config ID") plt.xticks(rotation=90) plt.ylabel(opt_tgt) - plt.yscale('log') + plt.yscale("log") extra_title = "(lower is better)" if ascending else "(lower is better)" plt.title(f"Top {top_n} configs {opt_tgt} {extra_title}") plt.show() # type: ignore[no-untyped-call] diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 40deb848fd..7275966350 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -13,10 +13,12 @@ from mlos_viz.util import expand_results_data_args -def plot(exp_data: Optional[ExperimentData] = None, *, - results_df: Optional[pandas.DataFrame] = None, - objectives: Optional[Dict[str, Literal["min", "max"]]] = None, - ) -> None: +def plot( + exp_data: Optional[ExperimentData] = None, + *, + results_df: Optional[pandas.DataFrame] = None, + objectives: Optional[Dict[str, Literal["min", "max"]]] = None, +) -> None: """ Plots the Experiment results data using dabl. 
@@ -40,17 +42,45 @@ def ignore_plotter_warnings() -> None: """Add some filters to ignore warnings from the plotter.""" # pylint: disable=import-outside-toplevel warnings.filterwarnings("ignore", category=FutureWarning) - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Could not infer format") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated") - warnings.filterwarnings("ignore", module="dabl", category=UserWarning, - message="Missing values in target_col have been removed for regression") + warnings.filterwarnings( + "ignore", module="dabl", category=UserWarning, message="Could not infer format" + ) + warnings.filterwarnings( + "ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers" + ) + warnings.filterwarnings( + "ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated" + ) + warnings.filterwarnings( + "ignore", + module="dabl", + category=UserWarning, + message="Missing values in target_col have been removed for regression", + ) from sklearn.exceptions import UndefinedMetricWarning - warnings.filterwarnings("ignore", module="sklearn", category=UndefinedMetricWarning, message="Recall is ill-defined") - warnings.filterwarnings("ignore", category=DeprecationWarning, - message="is_categorical_dtype is deprecated and will be removed in a future version.") - warnings.filterwarnings("ignore", category=DeprecationWarning, module="sklearn", - message="is_sparse is deprecated and will be removed in a future version.") + + warnings.filterwarnings( + "ignore", + module="sklearn", + category=UndefinedMetricWarning, + message="Recall is ill-defined", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + message="is_categorical_dtype is deprecated and will be removed in a future version.", + ) + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + module="sklearn", + message="is_sparse is deprecated and will be removed in a future version.", + ) from matplotlib._api.deprecation import MatplotlibDeprecationWarning - warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning, module="dabl", - message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed") + + warnings.filterwarnings( + "ignore", + category=MatplotlibDeprecationWarning, + module="dabl", + message="The legendHandles attribute was deprecated in Matplotlib 3.7 and will be removed", + ) diff --git a/mlos_viz/mlos_viz/tests/test_mlos_viz.py b/mlos_viz/mlos_viz/tests/test_mlos_viz.py index ecd072c287..6d393dca6a 100644 --- a/mlos_viz/mlos_viz/tests/test_mlos_viz.py +++ b/mlos_viz/mlos_viz/tests/test_mlos_viz.py @@ -28,5 +28,5 @@ def test_plot(mock_show: Mock, mock_boxplot: Mock, exp_data: ExperimentData) -> warnings.simplefilter("error") random.seed(42) plot(exp_data, filter_warnings=True) - assert mock_show.call_count >= 2 # from the two base plots and anything dabl did - assert mock_boxplot.call_count >= 1 # from anything dabl did + assert mock_show.call_count >= 2 # from the two base plots and anything dabl did + assert mock_boxplot.call_count >= 1 # from anything dabl did diff --git a/mlos_viz/mlos_viz/util.py b/mlos_viz/mlos_viz/util.py index deb5227bc3..cefc3080d9 100644 --- a/mlos_viz/mlos_viz/util.py +++ b/mlos_viz/mlos_viz/util.py @@ -34,7 +34,8 @@ def expand_results_data_args( Returns 
------- Tuple[pandas.DataFrame, Dict[str, bool]] - The results dataframe and the objectives columns in the dataframe, plus whether or not they are in ascending order. + The results dataframe and the objectives columns in the dataframe, plus + whether or not they are in ascending order. """ # Prepare the orderby columns. if results_df is None: @@ -47,11 +48,14 @@ def expand_results_data_args( raise ValueError("Must provide either exp_data or both results_df and objectives.") objectives = exp_data.objectives objs_cols: Dict[str, bool] = {} - for (opt_tgt, opt_dir) in objectives.items(): + for opt_tgt, opt_dir in objectives.items(): if opt_dir not in ["min", "max"]: raise ValueError(f"Unexpected optimization direction for target {opt_tgt}: {opt_dir}") ascending = opt_dir == "min" - if opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) and opt_tgt in results_df.columns: + if ( + opt_tgt.startswith(ExperimentData.RESULT_COLUMN_PREFIX) + and opt_tgt in results_df.columns + ): objs_cols[opt_tgt] = ascending elif ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt in results_df.columns: objs_cols[ExperimentData.RESULT_COLUMN_PREFIX + opt_tgt] = ascending diff --git a/mlos_viz/mlos_viz/version.py b/mlos_viz/mlos_viz/version.py index 1d10835cd0..6d75e70bb5 100644 --- a/mlos_viz/mlos_viz/version.py +++ b/mlos_viz/mlos_viz/version.py @@ -5,7 +5,7 @@ """Version number for the mlos_viz package.""" # NOTE: This should be managed by bumpversion. -VERSION = '0.5.1' +VERSION = "0.5.1" if __name__ == "__main__": print(VERSION) diff --git a/mlos_viz/setup.py b/mlos_viz/setup.py index 73fd0f3c66..4f5e8677d1 100644 --- a/mlos_viz/setup.py +++ b/mlos_viz/setup.py @@ -19,15 +19,16 @@ try: ns: Dict[str, str] = {} with open(f"{PKG_NAME}/version.py", encoding="utf-8") as version_file: - exec(version_file.read(), ns) # pylint: disable=exec-used - VERSION = ns['VERSION'] + exec(version_file.read(), ns) # pylint: disable=exec-used + VERSION = ns["VERSION"] except OSError: VERSION = "0.0.1-dev" warning(f"version.py not found, using dummy VERSION={VERSION}") try: from setuptools_scm import get_version - version = get_version(root='..', relative_to=__file__, fallback_version=VERSION) + + version = get_version(root="..", relative_to=__file__, fallback_version=VERSION) if version is not None: VERSION = version except ImportError: @@ -45,22 +46,22 @@ # be duplicated for now. def _get_long_desc_from_readme(base_url: str) -> dict: pkg_dir = os.path.dirname(__file__) - readme_path = os.path.join(pkg_dir, 'README.md') + readme_path = os.path.join(pkg_dir, "README.md") if not os.path.isfile(readme_path): return { - 'long_description': 'missing', + "long_description": "missing", } - jsonc_re = re.compile(r'```jsonc') - link_re = re.compile(r'\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)') - with open(readme_path, mode='r', encoding='utf-8') as readme_fh: + jsonc_re = re.compile(r"```jsonc") + link_re = re.compile(r"\]\(([^:#)]+)(#[a-zA-Z0-9_-]+)?\)") + with open(readme_path, mode="r", encoding="utf-8") as readme_fh: lines = readme_fh.readlines() # Tweak the lexers for local expansion by pygments instead of github's. - lines = [link_re.sub(f"]({base_url}" + r'/\1\2)', line) for line in lines] + lines = [link_re.sub(f"]({base_url}" + r"/\1\2)", line) for line in lines] # Tweak source source code links. 
- lines = [jsonc_re.sub(r'```json', line) for line in lines] + lines = [jsonc_re.sub(r"```json", line) for line in lines] return { - 'long_description': ''.join(lines), - 'long_description_content_type': 'text/markdown', + "long_description": "".join(lines), + "long_description_content_type": "text/markdown", } @@ -68,23 +69,23 @@ def _get_long_desc_from_readme(base_url: str) -> dict: # construct special 'full' extra that adds requirements for all built-in # backend integrations and additional extra features. -extra_requires['full'] = list(set(chain(*extra_requires.values()))) +extra_requires["full"] = list(set(chain(*extra_requires.values()))) -extra_requires['full-tests'] = extra_requires['full'] + [ - 'pytest', - 'pytest-forked', - 'pytest-xdist', - 'pytest-cov', - 'pytest-local-badge', +extra_requires["full-tests"] = extra_requires["full"] + [ + "pytest", + "pytest-forked", + "pytest-xdist", + "pytest-cov", + "pytest-local-badge", ] setup( version=VERSION, install_requires=[ - 'mlos-bench==' + VERSION, - 'dabl>=0.2.6', - 'matplotlib<3.9', # FIXME: https://github.com/dabl/dabl/pull/341 + "mlos-bench==" + VERSION, + "dabl>=0.2.6", + "matplotlib<3.9", # FIXME: https://github.com/dabl/dabl/pull/341 ], extras_require=extra_requires, - **_get_long_desc_from_readme('https://github.com/microsoft/MLOS/tree/main/mlos_viz'), + **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_viz"), ) From 382086a19d4556cab479c2918c5e3dafaaacbbca Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 19:42:45 +0000 Subject: [PATCH 31/54] comments --- Makefile | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Makefile b/Makefile index 15eaa9b74c..62ea9a3359 100644 --- a/Makefile +++ b/Makefile @@ -195,6 +195,8 @@ build/docformatter.mlos_core.${CONDA_ENV_NAME}.build-stamp: $(MLOS_CORE_PYTHON_F build/docformatter.mlos_bench.${CONDA_ENV_NAME}.build-stamp: $(MLOS_BENCH_PYTHON_FILES) build/docformatter.mlos_viz.${CONDA_ENV_NAME}.build-stamp: $(MLOS_VIZ_PYTHON_FILES) +# docformatter returns non-zero when it changes anything so instead we ignore that +# return code and just have it recheck itself immediately build/docformatter.%.${CONDA_ENV_NAME}.build-stamp: $(DOCFORMATTER_COMMON_PREREQS) # Reformat python file docstrings with docformatter. 
conda run -n ${CONDA_ENV_NAME} docformatter --in-place $(filter %.py,$+) || true From ddd4d1edf7760b8ec9bc9ad2cf01a62abeef24ad Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 20:01:53 +0000 Subject: [PATCH 32/54] tweak for consistency --- .../apps/redis/scripts/local/process_redis_results.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py index d7f35f3d17..8b979e5014 100644 --- a/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py +++ b/mlos_bench/mlos_bench/config/environments/apps/redis/scripts/local/process_redis_results.py @@ -38,7 +38,10 @@ def _main(input_file: str, output_file: str) -> None: if __name__ == "__main__": parser = argparse.ArgumentParser(description="Post-process Redis benchmark results.") - parser.add_argument("input", help="Redis benchmark results (downloaded from a remote VM).") + parser.add_argument( + "input", + help="Redis benchmark results (downloaded from a remote VM).", + ) parser.add_argument( "output", help="Converted Redis benchmark data (to be consumed by OS Autotune framework).", From 0094a81fce60953090d247c8954ac37a0ebce4c2 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 20:30:18 +0000 Subject: [PATCH 33/54] some line by line tweaks to black reformatting rules --- .../environments/base_environment.py | 20 +++++++++++-------- .../mlos_bench/environments/composite_env.py | 12 +++++++++-- .../environments/local/local_env.py | 13 +++++++----- .../environments/local/local_fileshare_env.py | 9 ++++++--- .../environments/remote/network_env.py | 3 ++- .../environments/remote/remote_env.py | 4 +++- .../environments/remote/saas_env.py | 3 ++- .../mlos_bench/optimizers/base_optimizer.py | 10 ++++++++-- .../optimizers/convert_configspace.py | 17 +++++++++++----- .../optimizers/grid_search_optimizer.py | 4 +++- .../optimizers/mlos_core_optimizer.py | 7 +++++-- 11 files changed, 71 insertions(+), 31 deletions(-) diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index 0c3300aa10..b100d3c974 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -150,7 +150,8 @@ def __init__( tunables = TunableGroups() groups = self._expand_groups( - config.get("tunable_params", []), (global_config or {}).get("tunable_params_map", {}) + config.get("tunable_params", []), + (global_config or {}).get("tunable_params_map", {}), ) _LOG.debug("Tunable groups for: '%s' :: %s", name, groups) @@ -185,7 +186,8 @@ def _validate_json_config(self, config: dict, name: str) -> None: @staticmethod def _expand_groups( - groups: Iterable[str], groups_exp: Dict[str, Union[str, Sequence[str]]] + groups: Iterable[str], + groups_exp: Dict[str, Union[str, Sequence[str]]], ) -> List[str]: """ Expand `$tunable_group` into actual names of the tunable groups. 
@@ -222,7 +224,8 @@ def _expand_groups( @staticmethod def _expand_vars( - params: Dict[str, TunableValue], global_config: Dict[str, TunableValue] + params: Dict[str, TunableValue], + global_config: Dict[str, TunableValue], ) -> dict: """Expand `$var` into actual values of the variables.""" return DictTemplater(params).expand_vars(extra_source_dict=global_config) @@ -277,8 +280,8 @@ def __repr__(self) -> str: def pprint(self, indent: int = 4, level: int = 0) -> str: """ - Pretty-print the environment configuration. For composite environments, print - all children environments as well. + Pretty-print the environment configuration. + For composite environments, print all children environments as well. Parameters ---------- @@ -297,9 +300,10 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ - Plug tunable values into the base config. If the tunable group is unknown, - ignore it (it might belong to another environment). This method should never - mutate the original config or the tunables. + Plug tunable values into the base config.\ + If the tunable group is unknown, ignore it (it might belong to another + environment). + This method should never mutate the original config or the tunables. Parameters ---------- diff --git a/mlos_bench/mlos_bench/environments/composite_env.py b/mlos_bench/mlos_bench/environments/composite_env.py index 6f8961ce06..e37d5273eb 100644 --- a/mlos_bench/mlos_bench/environments/composite_env.py +++ b/mlos_bench/mlos_bench/environments/composite_env.py @@ -78,13 +78,21 @@ def __init__( for child_config_file in config.get("include_children", []): for env in self._config_loader_service.load_environment_list( - child_config_file, tunables, global_config, self._const_args, self._service + child_config_file, + tunables, + global_config, + self._const_args, + self._service, ): self._add_child(env, tunables) for child_config in config.get("children", []): env = self._config_loader_service.build_environment( - child_config, tunables, global_config, self._const_args, self._service + child_config, + tunables, + global_config, + self._const_args, + self._service, ) self._add_child(env, tunables) diff --git a/mlos_bench/mlos_bench/environments/local/local_env.py b/mlos_bench/mlos_bench/environments/local/local_env.py index 071827d364..c3e81bb94d 100644 --- a/mlos_bench/mlos_bench/environments/local/local_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_env.py @@ -88,7 +88,7 @@ def __init__( def __enter__(self) -> Environment: assert self._temp_dir is None and self._temp_dir_context is None self._temp_dir_context = self._local_exec_service.temp_dir_context( - self.config.get("temp_dir") + self.config.get("temp_dir"), ) self._temp_dir = self._temp_dir_context.__enter__() return super().__enter__() @@ -194,7 +194,8 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: data = self._normalize_columns( pandas.read_csv( self._config_loader_service.resolve_path( - self._read_results_file, extra_paths=[self._temp_dir] + self._read_results_file, + extra_paths=[self._temp_dir], ), index_col=False, ) @@ -208,7 +209,6 @@ def run(self) -> Tuple[Status, datetime, Optional[Dict[str, TunableValue]]]: ) data = pandas.DataFrame([data.value.to_list()], columns=data.metric.to_list()) # Try to convert string metrics to numbers. 
- # type: ignore[assignment] # (false positive) data = data.apply( # type: ignore[assignment] # (false positive) pandas.to_numeric, errors="coerce", @@ -241,7 +241,8 @@ def status(self) -> Tuple[Status, datetime, List[Tuple[datetime, str, Any]]]: assert self._temp_dir is not None try: fname = self._config_loader_service.resolve_path( - self._read_telemetry_file, extra_paths=[self._temp_dir] + self._read_telemetry_file, + extra_paths=[self._temp_dir], ) # TODO: Use the timestamp of the CSV file as our status timestamp? @@ -303,7 +304,9 @@ def _local_exec(self, script: Iterable[str], cwd: Optional[str] = None) -> Tuple env_params = self._get_env_params() _LOG.info("Run script locally on: %s at %s with env %s", self, cwd, env_params) (return_code, stdout, stderr) = self._local_exec_service.local_exec( - script, env=env_params, cwd=cwd + script, + env=env_params, + cwd=cwd, ) if return_code != 0: _LOG.warning("ERROR: Local script returns code %d stderr:\n%s", return_code, stderr) diff --git a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py index 2996ea8cd2..14ba59f3f6 100644 --- a/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py +++ b/mlos_bench/mlos_bench/environments/local/local_fileshare_env.py @@ -88,7 +88,8 @@ def _template_from_to(self, config_key: str) -> List[Tuple[Template, Template]]: @staticmethod def _expand( - from_to: Iterable[Tuple[Template, Template]], params: Mapping[str, TunableValue] + from_to: Iterable[Tuple[Template, Template]], + params: Mapping[str, TunableValue], ) -> Generator[Tuple[str, str], None, None]: """ Substitute $var parameters in from/to path templates. @@ -129,7 +130,8 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - self._file_share_service.upload( self._params, self._config_loader_service.resolve_path( - path_from, extra_paths=[self._temp_dir] + path_from, + extra_paths=[self._temp_dir], ), path_to, ) @@ -154,7 +156,8 @@ def _download_files(self, ignore_missing: bool = False) -> None: self._params, path_from, self._config_loader_service.resolve_path( - path_to, extra_paths=[self._temp_dir] + path_to, + extra_paths=[self._temp_dir], ), ) except FileNotFoundError as ex: diff --git a/mlos_bench/mlos_bench/environments/remote/network_env.py b/mlos_bench/mlos_bench/environments/remote/network_env.py index 3f36345b58..c87ddd0899 100644 --- a/mlos_bench/mlos_bench/environments/remote/network_env.py +++ b/mlos_bench/mlos_bench/environments/remote/network_env.py @@ -111,7 +111,8 @@ def teardown(self) -> None: # Else _LOG.info("Network tear down: %s", self) (status, params) = self._network_service.deprovision_network( - self._params, ignore_errors=True + self._params, + ignore_errors=True, ) if status.is_pending(): (status, _) = self._network_service.wait_network_deployment(params, is_setup=False) diff --git a/mlos_bench/mlos_bench/environments/remote/remote_env.py b/mlos_bench/mlos_bench/environments/remote/remote_env.py index 87a76be45a..c48b84cfdd 100644 --- a/mlos_bench/mlos_bench/environments/remote/remote_env.py +++ b/mlos_bench/mlos_bench/environments/remote/remote_env.py @@ -177,7 +177,9 @@ def _remote_exec(self, script: Iterable[str]) -> Tuple[Status, datetime, Optiona env_params = self._get_env_params() _LOG.debug("Submit script: %s with %s", self, env_params) (status, output) = self._remote_exec_service.remote_exec( - script, config=self._params, env_params=env_params + script, + config=self._params, + env_params=env_params, 
) _LOG.debug("Script submitted: %s %s :: %s", self, status, output) if status in {Status.PENDING, Status.SUCCEEDED}: diff --git a/mlos_bench/mlos_bench/environments/remote/saas_env.py b/mlos_bench/mlos_bench/environments/remote/saas_env.py index 0b64bc679f..5d1ba9d800 100644 --- a/mlos_bench/mlos_bench/environments/remote/saas_env.py +++ b/mlos_bench/mlos_bench/environments/remote/saas_env.py @@ -89,7 +89,8 @@ def setup(self, tunables: TunableGroups, global_config: Optional[dict] = None) - return False (status, _) = self._config_service.configure( - self._params, self._tunable_params.get_param_values() + self._params, + self._tunable_params.get_param_values(), ) if not status.is_succeeded(): return False diff --git a/mlos_bench/mlos_bench/optimizers/base_optimizer.py b/mlos_bench/mlos_bench/optimizers/base_optimizer.py index 9cdecffc81..6fa7ad87f4 100644 --- a/mlos_bench/mlos_bench/optimizers/base_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/base_optimizer.py @@ -303,14 +303,20 @@ def register( from the dataframe that's being MINIMIZED. """ _LOG.info( - "Iteration %d :: Register: %s = %s score: %s", self._iter, tunables, status, score + "Iteration %d :: Register: %s = %s score: %s", + self._iter, + tunables, + status, + score, ) if status.is_succeeded() == (score is None): # XOR raise ValueError("Status and score must be consistent.") return self._get_scores(status, score) def _get_scores( - self, status: Status, scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]] + self, + status: Status, + scores: Optional[Union[Dict[str, TunableValue], Dict[str, float]]], ) -> Optional[Dict[str, float]]: """ Extract a scalar benchmark score from the dataframe. Change the sign if we are diff --git a/mlos_bench/mlos_bench/optimizers/convert_configspace.py b/mlos_bench/mlos_bench/optimizers/convert_configspace.py index f53e308352..e4ff9897fa 100644 --- a/mlos_bench/mlos_bench/optimizers/convert_configspace.py +++ b/mlos_bench/mlos_bench/optimizers/convert_configspace.py @@ -47,7 +47,9 @@ def _normalize_weights(weights: List[float]) -> List[float]: def _tunable_to_configspace( - tunable: Tunable, group_name: Optional[str] = None, cost: int = 0 + tunable: Tunable, + group_name: Optional[str] = None, + cost: int = 0, ) -> ConfigurationSpace: """ Convert a single Tunable to an equivalent set of ConfigSpace Hyperparameter objects, @@ -88,11 +90,13 @@ def _tunable_to_configspace( distribution = Uniform() elif tunable.distribution == "normal": distribution = Normal( - mu=tunable.distribution_params["mu"], sigma=tunable.distribution_params["sigma"] + mu=tunable.distribution_params["mu"], + sigma=tunable.distribution_params["sigma"], ) elif tunable.distribution == "beta": distribution = Beta( - alpha=tunable.distribution_params["alpha"], beta=tunable.distribution_params["beta"] + alpha=tunable.distribution_params["alpha"], + beta=tunable.distribution_params["beta"], ) elif tunable.distribution is not None: raise TypeError(f"Invalid Distribution Type: {tunable.distribution}") @@ -170,7 +174,8 @@ def _tunable_to_configspace( def tunable_groups_to_configspace( - tunables: TunableGroups, seed: Optional[int] = None + tunables: TunableGroups, + seed: Optional[int] = None, ) -> ConfigurationSpace: """ Convert TunableGroups to hyperparameters in ConfigurationSpace. 
@@ -194,7 +199,9 @@ def tunable_groups_to_configspace( prefix="", delimiter="", configuration_space=_tunable_to_configspace( - tunable, group.name, group.get_current_cost() + tunable, + group.name, + group.get_current_cost(), ), ) return space diff --git a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py index 568cfff43f..8bcd090415 100644 --- a/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/grid_search_optimizer.py @@ -60,7 +60,9 @@ def _sanity_check(self) -> None: ) if size > self._max_iter: _LOG.warning( - "Grid search size %d, is greater than max iterations %d", size, self._max_iter + "Grid search size %d, is greater than max iterations %d", + size, + self._max_iter, ) def _get_grid(self) -> Tuple[Tuple[str, ...], Dict[Tuple[TunableValue, ...], None]]: diff --git a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py index dfaa345548..e8c1195421 100644 --- a/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/mlos_core_optimizer.py @@ -202,7 +202,9 @@ def register( score: Optional[Dict[str, TunableValue]] = None, ) -> Optional[Dict[str, float]]: registered_score = super().register( - tunables, status, score + tunables, + status, + score, ) # Sign-adjusted for MINIMIZATION if status.is_completed(): assert registered_score is not None @@ -211,7 +213,8 @@ def register( # TODO: Specify (in the config) which metrics to pass to the optimizer. # Issue: https://github.com/microsoft/MLOS/issues/745 self._opt.register( - configs=df_config, scores=pd.DataFrame([registered_score], dtype=float) + configs=df_config, + scores=pd.DataFrame([registered_score], dtype=float), ) return registered_score From dd67302a381faeae770d251b33f7bd44faf804f6 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 20:39:36 +0000 Subject: [PATCH 34/54] let docstrings get reformatted again --- .../mlos_bench/environments/base_environment.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index b100d3c974..2c51a5ee8d 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -280,8 +280,8 @@ def __repr__(self) -> str: def pprint(self, indent: int = 4, level: int = 0) -> str: """ - Pretty-print the environment configuration. - For composite environments, print all children environments as well. + Pretty-print the environment configuration. For composite environments, print + all children environments as well. Parameters ---------- @@ -300,10 +300,9 @@ def pprint(self, indent: int = 4, level: int = 0) -> str: def _combine_tunables(self, tunables: TunableGroups) -> Dict[str, TunableValue]: """ - Plug tunable values into the base config.\ - If the tunable group is unknown, ignore it (it might belong to another - environment). - This method should never mutate the original config or the tunables. + Plug tunable values into the base config. If the tunable group is unknown, + ignore it (it might belong to another environment). This method should never + mutate the original config or the tunables. 
Parameters ---------- From 8f2efb5c28e1dee4239a8ac178697fd666eba03b Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 20:49:44 +0000 Subject: [PATCH 35/54] more line by line tweaks --- mlos_bench/mlos_bench/schedulers/base_scheduler.py | 4 +++- mlos_bench/mlos_bench/schedulers/sync_scheduler.py | 5 ++--- mlos_bench/mlos_bench/services/local/local_exec.py | 10 ++++++++-- .../mlos_bench/services/local/temp_dir_context.py | 8 ++++++-- .../mlos_bench/services/remote/azure/azure_auth.py | 4 +++- 5 files changed, 22 insertions(+), 9 deletions(-) diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py index c268aab14c..7b0e526608 100644 --- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py @@ -60,7 +60,9 @@ def __init__( """ self.global_config = global_config config = merge_parameters( - dest=config.copy(), source=global_config, required_keys=["experiment_id", "trial_id"] + dest=config.copy(), + source=global_config, + required_keys=["experiment_id", "trial_id"], ) self._experiment_id = config["experiment_id"].strip() diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index 96cf15cdc9..e56d15ca17 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -49,9 +49,8 @@ def run_trial(self, trial: Storage.Trial) -> None: trial.update(Status.FAILED, datetime.now(UTC)) return - (status, timestamp, results) = ( - self.environment.run() - ) # Block and wait for the final result. + # Block and wait for the final result. + (status, timestamp, results) = self.environment.run() _LOG.info("Results: %s :: %s\n%s", trial.tunables, status, results) # In async mode (TODO), poll the environment for status and telemetry diff --git a/mlos_bench/mlos_bench/services/local/local_exec.py b/mlos_bench/mlos_bench/services/local/local_exec.py index f595c75a89..a1339312a6 100644 --- a/mlos_bench/mlos_bench/services/local/local_exec.py +++ b/mlos_bench/mlos_bench/services/local/local_exec.py @@ -102,7 +102,10 @@ def __init__( New methods to register with the service. """ super().__init__( - config, global_config, parent, self.merge_methods(methods, [self.local_exec]) + config, + global_config, + parent, + self.merge_methods(methods, [self.local_exec]), ) self.abort_on_error = self.config.get("abort_on_error", True) @@ -180,7 +183,10 @@ def _resolve_cmdline_script_path(self, subcmd_tokens: List[str]) -> List[str]: return subcmd_tokens def _local_exec_script( - self, script_line: str, env_params: Optional[Mapping[str, "TunableValue"]], cwd: str + self, + script_line: str, + env_params: Optional[Mapping[str, "TunableValue"]], + cwd: str, ) -> Tuple[int, str, str]: """ Execute the script from `script_path` in a local process. diff --git a/mlos_bench/mlos_bench/services/local/temp_dir_context.py b/mlos_bench/mlos_bench/services/local/temp_dir_context.py index 06bb32bc5f..e65a45934b 100644 --- a/mlos_bench/mlos_bench/services/local/temp_dir_context.py +++ b/mlos_bench/mlos_bench/services/local/temp_dir_context.py @@ -50,7 +50,10 @@ def __init__( New methods to register with the service. 
""" super().__init__( - config, global_config, parent, self.merge_methods(methods, [self.temp_dir_context]) + config, + global_config, + parent, + self.merge_methods(methods, [self.temp_dir_context]), ) self._temp_dir = self.config.get("temp_dir") if self._temp_dir: @@ -61,7 +64,8 @@ def __init__( _LOG.info("%s: temp dir: %s", self, self._temp_dir) def temp_dir_context( - self, path: Optional[str] = None + self, + path: Optional[str] = None, ) -> Union[TemporaryDirectory, nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py index bded5fb99e..619e8eed90 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_auth.py @@ -109,7 +109,9 @@ def _init_sp(self) -> None: # Reauthenticate as the service principal. self._cred = azure_id.CertificateCredential( - tenant_id=tenant_id, client_id=sp_client_id, certificate_data=cert_bytes + tenant_id=tenant_id, + client_id=sp_client_id, + certificate_data=cert_bytes, ) def get_access_token(self) -> str: From a6aa7bf2dc91421e3a2841ef81a859ddc6a1e56c Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Mon, 8 Jul 2024 20:51:58 +0000 Subject: [PATCH 36/54] comments --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6865bf3d71..f70030a576 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,3 +14,5 @@ black = true style = "numpy" pre-summary-newline = true close-quotes-on-newline = true + +# TODO: move pylintrc and some setup.cfg configs here From 8cea74d3a1df1071b5f804b0fc0a32427fb07bc9 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 16:00:24 +0000 Subject: [PATCH 37/54] more line by line tweaks --- .../remote/azure/azure_deployment_services.py | 33 +++++--- .../services/remote/azure/azure_fileshare.py | 6 +- .../remote/azure/azure_network_services.py | 12 +-- .../services/remote/azure/azure_saas.py | 22 ++++-- .../remote/azure/azure_vm_services.py | 75 ++++++++++--------- 5 files changed, 84 insertions(+), 64 deletions(-) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index 24c4242e8f..dfa1a60555 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -37,11 +37,11 @@ class AzureDeploymentService(Service, metaclass=abc.ABCMeta): _URL_DEPLOY = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Resources" - + "/deployments/{deployment_name}" - + "?api-version=2022-05-01" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Resources" + "/deployments/{deployment_name}" + "?api-version=2022-05-01" ) def __init__( @@ -289,7 +289,10 @@ def _wait_deployment(self, params: dict, *, is_setup: bool) -> Tuple[Status, dic return self._wait_while(self._check_deployment, Status.PENDING, params) def _wait_while( - self, func: Callable[[dict], Tuple[Status, dict]], loop_status: Status, params: dict + self, + func: Callable[[dict], Tuple[Status, dict]], + loop_status: Status, + params: dict, ) -> Tuple[Status, dict]: """ Invoke `func` periodically while the status is equal to `loop_status`. 
Return @@ -312,7 +315,9 @@ def _wait_while( """ params = self._set_default_params(params) config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"] + dest=self.config.copy(), + source=params, + required_keys=["deploymentName"], ) poll_period = params.get("pollInterval", self._poll_interval) @@ -347,9 +352,8 @@ def _wait_while( _LOG.warning("Request timed out: %s", params) return (Status.TIMED_OUT, {}) - def _check_deployment( - self, params: dict - ) -> Tuple[Status, dict]: # pylint: disable=too-many-return-statements + def _check_deployment(self, params: dict) -> Tuple[Status, dict]: + # pylint: disable=too-many-return-statements """ Check if Azure deployment exists. Return SUCCEEDED if true, PENDING otherwise. @@ -436,7 +440,9 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: raise ValueError(f"Missing deployment template: {self}") params = self._set_default_params(params) config = merge_parameters( - dest=self.config.copy(), source=params, required_keys=["deploymentName"] + dest=self.config.copy(), + source=params, + required_keys=["deploymentName"], ) _LOG.info("Deploy: %s :: %s", config["deploymentName"], params) @@ -470,7 +476,10 @@ def _provision_resource(self, params: dict) -> Tuple[Status, dict]: _LOG.debug("Request: PUT %s\n%s", url, json.dumps(json_req, indent=2)) response = requests.put( - url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout + url, + json=json_req, + headers=self._get_headers(), + timeout=self._request_timeout, ) if _LOG.isEnabledFor(logging.DEBUG): diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index ddd41afcc2..0f09694489 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -98,7 +98,11 @@ def download( raise FileNotFoundError(f"Cannot download: {remote_path}") from ex def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True + self, + params: dict, + local_path: str, + remote_path: str, + recursive: bool = True, ) -> None: super().upload(params, local_path, remote_path, recursive) self._upload(local_path, remote_path, recursive, set()) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index 9c66fc7b0c..2ca6130a2f 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -29,12 +29,12 @@ class AzureNetworkService(AzureDeploymentService, SupportsNetworkProvisioning): # From: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/virtual-networks?view=rest-virtualnetwork-2023-05-01 # pylint: disable=line-too-long # noqa _URL_DEPROVISION = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Network" - + "/virtualNetwork/{vnet_name}" - + "/delete" - + "?api-version=2023-05-01" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Network" + "/virtualNetwork/{vnet_name}" + "/delete" + "?api-version=2023-05-01" ) def __init__( diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py index 9a2081c90f..042e599f0b 100644 --- 
a/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_saas.py @@ -29,12 +29,12 @@ class AzureSaaSConfigService(Service, SupportsRemoteConfig): _URL_CONFIGURE = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/{provider}" - + "/{server_type}/{vm_name}" - + "/{update}" - + "?api-version={api_version}" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/{provider}" + "/{server_type}/{vm_name}" + "/{update}" + "?api-version={api_version}" ) def __init__( @@ -183,7 +183,10 @@ def _get_headers(self) -> dict: return self._parent.get_auth_headers() def _config_one( - self, config: Dict[str, Any], param_name: str, param_value: Any + self, + config: Dict[str, Any], + param_name: str, + param_value: Any, ) -> Tuple[Status, dict]: """ Update a single parameter of the Azure DB service. @@ -270,7 +273,10 @@ def _config_batch(self, config: Dict[str, Any], params: Dict[str, Any]) -> Tuple } _LOG.debug("Request: POST %s", url) response = requests.post( - url, headers=self._get_headers(), json=json_req, timeout=self._request_timeout + url, + headers=self._get_headers(), + json=json_req, + timeout=self._request_timeout, ) _LOG.debug("Response: %s :: %s", response, response.text) if response.status_code == 504: diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index 06e71780e8..2bc2bb35cd 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -41,34 +41,34 @@ class AzureVMService( # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/start _URL_START = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/start" - + "?api-version=2022-03-01" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Compute" + "/virtualMachines/{vm_name}" + "/start" + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/power-off _URL_STOP = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/powerOff" - + "?api-version=2022-03-01" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Compute" + "/virtualMachines/{vm_name}" + "/powerOff" + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/deallocate _URL_DEALLOCATE = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/deallocate" - + "?api-version=2022-03-01" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Compute" + "/virtualMachines/{vm_name}" + "/deallocate" + "?api-version=2022-03-01" ) # TODO: This is probably the more correct URL to use for the deprovision operation. 
@@ -79,35 +79,35 @@ class AzureVMService( # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/delete # _URL_DEPROVISION = ( - # "https://management.azure.com" + - # "/subscriptions/{subscription}" + - # "/resourceGroups/{resource_group}" + - # "/providers/Microsoft.Compute" + - # "/virtualMachines/{vm_name}" + - # "/delete" + + # "https://management.azure.com" + # "/subscriptions/{subscription}" + # "/resourceGroups/{resource_group}" + # "/providers/Microsoft.Compute" + # "/virtualMachines/{vm_name}" + # "/delete" # "?api-version=2022-03-01" # ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/restart _URL_REBOOT = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/restart" - + "?api-version=2022-03-01" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Compute" + "/virtualMachines/{vm_name}" + "/restart" + "?api-version=2022-03-01" ) # From: https://docs.microsoft.com/en-us/rest/api/compute/virtual-machines/run-command _URL_REXEC_RUN = ( "https://management.azure.com" - + "/subscriptions/{subscription}" - + "/resourceGroups/{resource_group}" - + "/providers/Microsoft.Compute" - + "/virtualMachines/{vm_name}" - + "/runCommand" - + "?api-version=2022-03-01" + "/subscriptions/{subscription}" + "/resourceGroups/{resource_group}" + "/providers/Microsoft.Compute" + "/virtualMachines/{vm_name}" + "/runCommand" + "?api-version=2022-03-01" ) def __init__( @@ -181,7 +181,8 @@ def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self- if "vmName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vmName']}-deployment" _LOG.info( - "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] + "deploymentName missing from params. 
Defaulting to '%s'.", + params["deploymentName"], ) return params From 839b96a187a6aba2cafd53d5747605412f233658 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 19:36:40 +0000 Subject: [PATCH 38/54] tweaks --- .../services/remote/ssh/ssh_fileshare.py | 26 ++++++++++++++++--- .../services/remote/ssh/ssh_host_service.py | 21 ++++++++++++--- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py index db44a5411d..383fcfbd20 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_fileshare.py @@ -76,7 +76,11 @@ async def _start_file_copy( return await scp(srcpaths=srcpaths, dstpath=dstpath, recurse=recursive, preserve=True) def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True + self, + params: dict, + remote_path: str, + local_path: str, + recursive: bool = True, ) -> None: params = merge_parameters( dest=self.config.copy(), @@ -87,13 +91,23 @@ def download( ) super().download(params, remote_path, local_path, recursive) file_copy_future = self._run_coroutine( - self._start_file_copy(params, CopyMode.DOWNLOAD, local_path, remote_path, recursive) + self._start_file_copy( + params, + CopyMode.DOWNLOAD, + local_path, + remote_path, + recursive, + ) ) try: file_copy_future.result() except (OSError, SFTPError) as ex: _LOG.error( - "Failed to download %s to %s from %s: %s", remote_path, local_path, params, ex + "Failed to download %s to %s from %s: %s", + remote_path, + local_path, + params, + ex, ) if isinstance(ex, SFTPNoSuchFile) or ( isinstance(ex, SFTPFailure) @@ -108,7 +122,11 @@ def download( raise ex def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True + self, + params: dict, + local_path: str, + remote_path: str, + recursive: bool = True, ) -> None: params = merge_parameters( dest=self.config.copy(), diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py index db7dbdffe0..36f1f7866b 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_host_service.py @@ -67,7 +67,10 @@ def __init__( self._shell = self.config.get("ssh_shell", "/bin/bash") async def _run_cmd( - self, params: dict, script: Iterable[str], env_params: dict + self, + params: dict, + script: Iterable[str], + env_params: dict, ) -> SSHCompletedProcess: """ Runs a command asynchronously on a host via SSH. @@ -100,11 +103,17 @@ async def _run_cmd( script_str = "\n".join(script_lines) _LOG.debug("Running script on %s:\n%s", connection, script_str) return await connection.run( - script_str, check=False, timeout=self._request_timeout, env=env_params + script_str, + check=False, + timeout=self._request_timeout, + env=env_params, ) def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict + self, + script: Iterable[str], + config: dict, + env_params: dict, ) -> Tuple["Status", dict]: """ Start running a command on remote host OS. 
@@ -135,7 +144,11 @@ def remote_exec( ], ) config["asyncRemoteExecResultsFuture"] = self._run_coroutine( - self._run_cmd(config, script, env_params) + self._run_cmd( + config, + script, + env_params, + ) ) return (Status.PENDING, config) From d811da57ec024b4cafa6d1c47ddcd72d9754fb20 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 19:38:49 +0000 Subject: [PATCH 39/54] tweaks --- .../mlos_bench/services/types/fileshare_type.py | 12 ++++++++++-- .../mlos_bench/services/types/local_exec_type.py | 3 ++- .../services/types/network_provisioner_type.py | 4 +++- .../mlos_bench/services/types/remote_exec_type.py | 5 ++++- 4 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mlos_bench/mlos_bench/services/types/fileshare_type.py b/mlos_bench/mlos_bench/services/types/fileshare_type.py index c2ff153ac7..c69516992b 100644 --- a/mlos_bench/mlos_bench/services/types/fileshare_type.py +++ b/mlos_bench/mlos_bench/services/types/fileshare_type.py @@ -12,7 +12,11 @@ class SupportsFileShareOps(Protocol): """Protocol interface for file share operations.""" def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True + self, + params: dict, + remote_path: str, + local_path: str, + recursive: bool = True, ) -> None: """ Downloads contents from a remote share path to a local path. @@ -32,7 +36,11 @@ def download( """ def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True + self, + params: dict, + local_path: str, + remote_path: str, + recursive: bool = True, ) -> None: """ Uploads contents from a local path to remote share path. diff --git a/mlos_bench/mlos_bench/services/types/local_exec_type.py b/mlos_bench/mlos_bench/services/types/local_exec_type.py index 9c4d2dc224..d0c8c357f0 100644 --- a/mlos_bench/mlos_bench/services/types/local_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/local_exec_type.py @@ -58,7 +58,8 @@ def local_exec( """ def temp_dir_context( - self, path: Optional[str] = None + self, + path: Optional[str] = None, ) -> Union[tempfile.TemporaryDirectory, contextlib.nullcontext]: """ Create a temp directory or use the provided path. diff --git a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py index 19e7b16350..3525fbdee1 100644 --- a/mlos_bench/mlos_bench/services/types/network_provisioner_type.py +++ b/mlos_bench/mlos_bench/services/types/network_provisioner_type.py @@ -54,7 +54,9 @@ def wait_network_deployment(self, params: dict, *, is_setup: bool) -> Tuple["Sta """ def deprovision_network( - self, params: dict, ignore_errors: bool = True + self, + params: dict, + ignore_errors: bool = True, ) -> Tuple["Status", dict]: """ Deprovisions the Network by deleting it. diff --git a/mlos_bench/mlos_bench/services/types/remote_exec_type.py b/mlos_bench/mlos_bench/services/types/remote_exec_type.py index dd105f7a41..b6285a8f96 100644 --- a/mlos_bench/mlos_bench/services/types/remote_exec_type.py +++ b/mlos_bench/mlos_bench/services/types/remote_exec_type.py @@ -19,7 +19,10 @@ class SupportsRemoteExec(Protocol): """ def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict + self, + script: Iterable[str], + config: dict, + env_params: dict, ) -> Tuple["Status", dict]: """ Run a command on remote host OS. 
From a31e8422f42cba7ce7f26274d3dcee1ec6b3db82 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 19:41:33 +0000 Subject: [PATCH 40/54] tweaks --- mlos_bench/mlos_bench/services/base_fileshare.py | 12 ++++++++++-- mlos_bench/mlos_bench/services/config_persistence.py | 10 +++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mlos_bench/mlos_bench/services/base_fileshare.py b/mlos_bench/mlos_bench/services/base_fileshare.py index c941e0b132..75ff3d2408 100644 --- a/mlos_bench/mlos_bench/services/base_fileshare.py +++ b/mlos_bench/mlos_bench/services/base_fileshare.py @@ -49,7 +49,11 @@ def __init__( @abstractmethod def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True + self, + params: dict, + remote_path: str, + local_path: str, + recursive: bool = True, ) -> None: """ Downloads contents from a remote share path to a local path. @@ -78,7 +82,11 @@ def download( @abstractmethod def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True + self, + params: dict, + local_path: str, + remote_path: str, + recursive: bool = True, ) -> None: """ Uploads contents from a local path to remote share path. diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index 2a90203fd1..ee1b7f7902 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -192,7 +192,7 @@ def load_config( ) raise ValueError( f"Failed to validate config {json_file_name} against " - + f"schema type {schema_type.name} at {schema_type.value}" + f"schema type {schema_type.name} at {schema_type.value}" ) from ex if isinstance(config, dict) and config.get("$schema"): # Remove $schema attributes from the config after we've validated @@ -382,9 +382,7 @@ def build_scheduler( """ (class_name, class_config) = self.prepare_class_load(config, global_config) # pylint: disable=import-outside-toplevel - from mlos_bench.schedulers.base_scheduler import ( - Scheduler, - ) + from mlos_bench.schedulers.base_scheduler import Scheduler inst = instantiate_from_config( Scheduler, # type: ignore[type-abstract] @@ -670,7 +668,9 @@ def load_services( return service def _load_tunables( - self, json_file_names: Iterable[str], parent: TunableGroups + self, + json_file_names: Iterable[str], + parent: TunableGroups, ) -> TunableGroups: """ Load a collection of tunable parameters from JSON files into the parent From ef71a8b88e659b1450efdeb427960bde2a01a721 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 19:46:26 +0000 Subject: [PATCH 41/54] tweaks --- mlos_bench/mlos_bench/storage/sql/common.py | 10 ++++++++-- mlos_bench/mlos_bench/storage/sql/schema.py | 18 +++++++++++++----- mlos_bench/mlos_bench/storage/sql/storage.py | 5 ++++- mlos_bench/mlos_bench/storage/sql/trial.py | 14 ++++++++++---- .../mlos_bench/storage/sql/trial_data.py | 4 +++- .../sql/tunable_config_trial_group_data.py | 15 ++++++++++----- 6 files changed, 48 insertions(+), 18 deletions(-) diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index 5fdc6c0731..e9cc9ed4b0 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -16,7 +16,10 @@ def get_trials( - engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: 
Optional[int] = None, ) -> Dict[int, TrialData]: """ Gets TrialData for the given experiment_data and optionally additionally restricted @@ -61,7 +64,10 @@ def get_trials( def get_results_df( - engine: Engine, schema: DbSchema, experiment_id: str, tunable_config_id: Optional[int] = None + engine: Engine, + schema: DbSchema, + experiment_id: str, + tunable_config_id: Optional[int] = None, ) -> pandas.DataFrame: """ Gets TrialData for the given experiment_data and optionally additionally restricted diff --git a/mlos_bench/mlos_bench/storage/sql/schema.py b/mlos_bench/mlos_bench/storage/sql/schema.py index 717dc70c2a..3900568b75 100644 --- a/mlos_bench/mlos_bench/storage/sql/schema.py +++ b/mlos_bench/mlos_bench/storage/sql/schema.py @@ -105,7 +105,11 @@ def __init__(self, engine: Engine): ) else: col_config_id = Column( - "config_id", Integer, nullable=False, primary_key=True, autoincrement=True + "config_id", + Integer, + nullable=False, + primary_key=True, + autoincrement=True, ) self.config = Table( @@ -153,7 +157,8 @@ def __init__(self, engine: Engine): Column("param_value", String(self._PARAM_VALUE_LEN)), PrimaryKeyConstraint("exp_id", "trial_id", "param_id"), ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id], ), ) @@ -166,7 +171,8 @@ def __init__(self, engine: Engine): Column("status", String(self._STATUS_LEN), nullable=False), UniqueConstraint("exp_id", "trial_id", "ts"), ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id], ), ) @@ -179,7 +185,8 @@ def __init__(self, engine: Engine): Column("metric_value", String(self._METRIC_VALUE_LEN)), PrimaryKeyConstraint("exp_id", "trial_id", "metric_id"), ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id], ), ) @@ -193,7 +200,8 @@ def __init__(self, engine: Engine): Column("metric_value", String(self._METRIC_VALUE_LEN)), UniqueConstraint("exp_id", "trial_id", "ts", "metric_id"), ForeignKeyConstraint( - ["exp_id", "trial_id"], [self.trial.c.exp_id, self.trial.c.trial_id] + ["exp_id", "trial_id"], + [self.trial.c.exp_id, self.trial.c.trial_id], ), ) diff --git a/mlos_bench/mlos_bench/storage/sql/storage.py b/mlos_bench/mlos_bench/storage/sql/storage.py index f3a317db59..6b6d11e699 100644 --- a/mlos_bench/mlos_bench/storage/sql/storage.py +++ b/mlos_bench/mlos_bench/storage/sql/storage.py @@ -24,7 +24,10 @@ class SqlStorage(Storage): """An implementation of the Storage interface using SQLAlchemy backend.""" def __init__( - self, config: dict, global_config: Optional[dict] = None, service: Optional[Service] = None + self, + config: dict, + global_config: Optional[dict] = None, + service: Optional[Service] = None, ): super().__init__(config, global_config, service) lazy_schema_create = self._config.pop("lazy_schema_create", False) diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 13233fd9a3..006f7761b8 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -75,8 +75,8 @@ def update( if cur_status.rowcount not in {1, -1}: _LOG.warning("Trial %s :: update failed: %s", self, status) raise RuntimeError( - f"Failed to update the status of the trial {self} to {status}." 
- + f" ({cur_status.rowcount} rows)" + f"Failed to update the status of the trial {self} to {status}. " + f"({cur_status.rowcount} rows)" ) if metrics: conn.execute( @@ -119,7 +119,10 @@ def update( return metrics def update_telemetry( - self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] + self, + status: Status, + timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]], ) -> None: super().update_telemetry(status, timestamp, metrics) # Make sure to convert the timestamp to UTC before storing it in the database. @@ -165,5 +168,8 @@ def _update_status(self, conn: Connection, status: Status, timestamp: datetime) ) except IntegrityError as ex: _LOG.warning( - "Status with that timestamp already exists: %s %s :: %s", self, timestamp, ex + "Status with that timestamp already exists: %s %s :: %s", + self, + timestamp, + ex, ) diff --git a/mlos_bench/mlos_bench/storage/sql/trial_data.py b/mlos_bench/mlos_bench/storage/sql/trial_data.py index 690492585b..40362b25fd 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial_data.py +++ b/mlos_bench/mlos_bench/storage/sql/trial_data.py @@ -56,7 +56,9 @@ def tunable_config(self) -> TunableConfigData: Note: this corresponds to the Trial object's "tunables" property. """ return TunableConfigSqlData( - engine=self._engine, schema=self._schema, tunable_config_id=self._tunable_config_id + engine=self._engine, + schema=self._schema, + tunable_config_id=self._tunable_config_id, ) @property diff --git a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py index 31a6df5879..5069e435b2 100644 --- a/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py +++ b/mlos_bench/mlos_bench/storage/sql/tunable_config_trial_group_data.py @@ -69,9 +69,8 @@ def _get_tunable_config_trial_group_id(self) -> int: ) row = tunable_config_trial_group.fetchone() assert row is not None - return row._tuple()[ - 0 - ] # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy + # pylint: disable=protected-access # following DeprecationWarning in sqlalchemy + return row._tuple()[0] @property def tunable_config(self) -> TunableConfigData: @@ -93,11 +92,17 @@ def trials(self) -> Dict[int, "TrialData"]: A dictionary of the trials' data, keyed by trial id. 
""" return common.get_trials( - self._engine, self._schema, self._experiment_id, self._tunable_config_id + self._engine, + self._schema, + self._experiment_id, + self._tunable_config_id, ) @property def results_df(self) -> pandas.DataFrame: return common.get_results_df( - self._engine, self._schema, self._experiment_id, self._tunable_config_id + self._engine, + self._schema, + self._experiment_id, + self._tunable_config_id, ) From 0f3dbf12d8b13000d90901d5afb2012f3818693d Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 19:48:54 +0000 Subject: [PATCH 42/54] tweaks --- mlos_bench/mlos_bench/storage/base_storage.py | 14 +++++++++++--- mlos_bench/mlos_bench/storage/storage_factory.py | 4 +++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index 41c9df0e5e..cd529c730c 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -167,7 +167,9 @@ def __exit__( else: assert exc_type and exc_val _LOG.warning( - "Finishing experiment: %s", self, exc_info=(exc_type, exc_val, exc_tb) + "Finishing experiment: %s", + self, + exc_info=(exc_type, exc_val, exc_tb), ) assert self._in_context self._teardown(is_ok) @@ -279,7 +281,10 @@ def load( @abstractmethod def pending_trials( - self, timestamp: datetime, *, running: bool + self, + timestamp: datetime, + *, + running: bool, ) -> Iterator["Storage.Trial"]: """ Return an iterator over the pending trials that are scheduled to run on or @@ -430,7 +435,10 @@ def update( @abstractmethod def update_telemetry( - self, status: Status, timestamp: datetime, metrics: List[Tuple[datetime, str, Any]] + self, + status: Status, + timestamp: datetime, + metrics: List[Tuple[datetime, str, Any]], ) -> None: """ Save the experiment's telemetry data and intermediate status. diff --git a/mlos_bench/mlos_bench/storage/storage_factory.py b/mlos_bench/mlos_bench/storage/storage_factory.py index 2de66a9aab..ea0201717d 100644 --- a/mlos_bench/mlos_bench/storage/storage_factory.py +++ b/mlos_bench/mlos_bench/storage/storage_factory.py @@ -12,7 +12,9 @@ def from_config( - config_file: str, global_configs: Optional[List[str]] = None, **kwargs: Any + config_file: str, + global_configs: Optional[List[str]] = None, + **kwargs: Any, ) -> Storage: """ Create a new storage object from JSON5 config file. 
From 3637e258c38a687712f6855def1fb4ff9a039d0f Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 19:57:45 +0000 Subject: [PATCH 43/54] tweaks --- .../cli/test_load_cli_config_examples.py | 28 +++++++++++++------ .../test_load_environment_config_examples.py | 20 +++++++++---- 2 files changed, 34 insertions(+), 14 deletions(-) diff --git a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py index 7add370011..3db11e6cb2 100644 --- a/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/cli/test_load_cli_config_examples.py @@ -42,9 +42,15 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = [ *locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, + ), + *locate_config_examples( + BUILTIN_TEST_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, ), - *locate_config_examples(BUILTIN_TEST_CONFIG_PATH, CONFIG_TYPE, filter_configs), ] assert configs @@ -52,7 +58,8 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.skip(reason="Use full Launcher test (below) instead now.") @pytest.mark.parametrize("config_path", configs) def test_load_cli_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: # pragma: no cover """Tests loading a config example.""" # pylint: disable=too-complex @@ -101,7 +108,8 @@ def test_load_cli_config_examples( @pytest.mark.parametrize("config_path", configs) def test_load_cli_config_examples_via_launcher( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: """Tests loading a config example via the Launcher.""" config = config_loader_service.load_config(config_path, ConfigSchema.CLI) @@ -154,28 +162,32 @@ def test_load_cli_config_examples_via_launcher( assert isinstance(launcher.environment, Environment) env_config = launcher.config_loader.load_config( - config["environment"], ConfigSchema.ENVIRONMENT + config["environment"], + ConfigSchema.ENVIRONMENT, ) assert check_class_name(launcher.environment, env_config["class"]) assert isinstance(launcher.optimizer, Optimizer) if "optimizer" in config: opt_config = launcher.config_loader.load_config( - config["optimizer"], ConfigSchema.OPTIMIZER + config["optimizer"], + ConfigSchema.OPTIMIZER, ) assert check_class_name(launcher.optimizer, opt_config["class"]) assert isinstance(launcher.storage, Storage) if "storage" in config: storage_config = launcher.config_loader.load_config( - config["storage"], ConfigSchema.STORAGE + config["storage"], + ConfigSchema.STORAGE, ) assert check_class_name(launcher.storage, storage_config["class"]) assert isinstance(launcher.scheduler, Scheduler) if "scheduler" in config: scheduler_config = launcher.config_loader.load_config( - config["scheduler"], ConfigSchema.SCHEDULER + config["scheduler"], + ConfigSchema.SCHEDULER, ) assert check_class_name(launcher.scheduler, scheduler_config["class"]) diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index c7d0f9ba44..51814e8a1b 100644 --- 
a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -41,7 +41,8 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) def test_load_environment_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: """Tests loading an environment config example.""" envs = load_environment_config_examples(config_loader_service, config_path) @@ -51,12 +52,14 @@ def test_load_environment_config_examples( def load_environment_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> List[Environment]: """Loads an environment config example.""" # Make sure that any "required_args" are provided. global_config = config_loader_service.load_config( - "experiments/experiment_test_config.jsonc", ConfigSchema.GLOBALS + "experiments/experiment_test_config.jsonc", + ConfigSchema.GLOBALS, ) global_config.setdefault("trial_id", 1) # normally populated by Launcher @@ -74,16 +77,21 @@ def load_environment_config_examples( for mock_service_config_path in mock_service_configs: mock_service_config = config_loader_service.load_config( - mock_service_config_path, ConfigSchema.SERVICE + mock_service_config_path, + ConfigSchema.SERVICE, ) config_loader_service.register( config_loader_service.build_service( - config=mock_service_config, parent=config_loader_service + config=mock_service_config, + parent=config_loader_service, ).export() ) envs = config_loader_service.load_environment_list( - config_path, tunable_groups, global_config, service=config_loader_service + config_path, + tunable_groups, + global_config, + service=config_loader_service, ) return envs From b062a84e2843f8d0c78e9b7781aeff5e426f597b Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 20:18:10 +0000 Subject: [PATCH 44/54] tweawks --- .../test_load_environment_config_examples.py | 11 ++++++----- .../globals/test_load_global_config_examples.py | 3 ++- .../optimizers/test_load_optimizer_config_examples.py | 7 +++++-- .../tests/config/schemas/cli/test_cli_schemas.py | 3 ++- .../schemas/environments/test_environment_schemas.py | 6 ++++-- 5 files changed, 19 insertions(+), 11 deletions(-) diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index 51814e8a1b..c4a29d8984 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -97,14 +97,16 @@ def load_environment_config_examples( composite_configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, "environments/root/" + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + "environments/root/", ) assert composite_configs @pytest.mark.parametrize("config_path", composite_configs) def test_load_composite_env_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: """Tests loading a composite env config example.""" envs = load_environment_config_examples(config_loader_service, config_path) @@ 
-124,9 +126,8 @@ def test_load_composite_env_config_examples( (composite_tunable, composite_group) = composite_env.tunable_params.get_tunable( child_tunable ) - assert ( - child_tunable is composite_tunable - ) # Check that the tunables are the same object. + # Check that the tunables are the same object. + assert child_tunable is composite_tunable if child_group.name not in checked_child_env_groups: assert child_group is composite_group checked_child_env_groups.add(child_group.name) diff --git a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py index c7525a2960..5940962478 100644 --- a/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/globals/test_load_global_config_examples.py @@ -52,7 +52,8 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) def test_load_globals_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.GLOBALS) diff --git a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py index e08f4d593b..4feefb8440 100644 --- a/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/optimizers/test_load_optimizer_config_examples.py @@ -29,14 +29,17 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, ) assert configs @pytest.mark.parametrize("config_path", configs) def test_load_optimizer_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.OPTIMIZER) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py index 404602b724..a47395e2d2 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/cli/test_cli_schemas.py @@ -43,7 +43,8 @@ def test_cli_configs_with_extra_param(test_case_name: str) -> None: certain places. """ check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.CLI + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.CLI, ) if TEST_CASES.by_path[test_case_name].test_case_type != "bad": # Unified schema has a hard time validating bad configs, so we skip it. 
diff --git a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py index 3e9abdbb90..3819f1848e 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/environments/test_environment_schemas.py @@ -81,8 +81,10 @@ def test_environment_configs_with_extra_param(test_case_name: str) -> None: in certain places. """ check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.ENVIRONMENT + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.ENVIRONMENT, ) check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.UNIFIED, ) From 76c4f670ef8c603332a18faec045f4ff0b9b6b50 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 20:24:04 +0000 Subject: [PATCH 45/54] tweaks --- mlos_bench/mlos_bench/launcher.py | 4 +++- .../services/remote/azure/azure_deployment_services.py | 2 +- .../services/remote/azure/azure_network_services.py | 3 ++- .../mlos_bench/services/remote/azure/azure_vm_services.py | 5 ++++- mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py | 3 ++- mlos_bench/mlos_bench/storage/sql/common.py | 4 +++- mlos_bench/mlos_bench/storage/sql/experiment.py | 5 ++++- .../config/schemas/optimizers/test_optimizer_schemas.py | 3 ++- .../config/schemas/schedulers/test_scheduler_schemas.py | 3 ++- .../tests/config/schemas/services/test_services_schemas.py | 3 ++- .../tests/config/schemas/storage/test_storage_schemas.py | 3 ++- .../config/services/test_load_service_config_examples.py | 4 +++- .../config/storage/test_load_storage_config_examples.py | 4 +++- .../mlos_bench/tests/environments/remote/test_ssh_env.py | 5 ++++- .../mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py | 4 +++- .../mlos_bench/tests/services/config_persistence_test.py | 3 ++- .../tests/services/local/local_exec_python_test.py | 4 +++- .../mlos_bench/tests/services/local/local_exec_test.py | 4 +++- .../tests/services/remote/ssh/test_ssh_host_service.py | 5 ++++- .../mlos_bench/tests/tunables/test_tunables_size_props.py | 3 ++- mlos_bench/mlos_bench/tunables/tunable.py | 5 ++++- mlos_core/mlos_core/spaces/adapters/__init__.py | 3 ++- .../mlos_core/tests/optimizers/bayesian_optimizers_test.py | 4 +++- mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py | 4 +++- mlos_viz/mlos_viz/base.py | 2 +- mlos_viz/mlos_viz/dabl.py | 5 ++++- 26 files changed, 71 insertions(+), 26 deletions(-) diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index 106b853043..f62170cc41 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -450,7 +450,9 @@ def _load_storage(self, args_storage: Optional[str]) -> Storage: class_config = self._config_loader.load_config(args_storage, ConfigSchema.STORAGE) assert isinstance(class_config, Dict) storage = self._config_loader.build_storage( - service=self._parent_service, config=class_config, global_config=self.global_config + service=self._parent_service, + config=class_config, + global_config=self.global_config, ) return storage diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index dfa1a60555..b7d35e2792 100644 --- 
a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -104,7 +104,7 @@ def __init__( self._deploy_params = merge_parameters(dest=deploy_params, source=global_config) else: _LOG.info( - "No deploymentTemplatePath provided. Deployment services will be unavailable." + "No deploymentTemplatePath provided. Deployment services will be unavailable.", ) @property diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py index 2ca6130a2f..29552de4f0 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_network_services.py @@ -86,7 +86,8 @@ def _set_default_params(self, params: dict) -> dict: # pylint: disable=no-self- if "vnetName" in params and "deploymentName" not in params: params["deploymentName"] = f"{params['vnetName']}-deployment" _LOG.info( - "deploymentName missing from params. Defaulting to '%s'.", params["deploymentName"] + "deploymentName missing from params. Defaulting to '%s'.", + params["deploymentName"], ) return params diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index 2bc2bb35cd..09d5986ef8 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -500,7 +500,10 @@ def remote_exec( _LOG.debug("Request: POST %s\n%s", url, json.dumps(json_req, indent=2)) response = requests.post( - url, json=json_req, headers=self._get_headers(), timeout=self._request_timeout + url, + json=json_req, + headers=self._get_headers(), + timeout=self._request_timeout, ) if _LOG.isEnabledFor(logging.DEBUG): diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 8c0b2b8b7a..89888dec47 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -377,7 +377,8 @@ def _get_connect_params(self, params: dict) -> dict: connect_params["username"] = str(self.config["ssh_username"]) priv_key_file: Optional[str] = params.get( - "ssh_priv_key_path", self.config["ssh_priv_key_path"] + "ssh_priv_key_path", + self.config["ssh_priv_key_path"], ) if priv_key_file: priv_key_file = os.path.expanduser(priv_key_file) diff --git a/mlos_bench/mlos_bench/storage/sql/common.py b/mlos_bench/mlos_bench/storage/sql/common.py index e9cc9ed4b0..3b0c6c31fb 100644 --- a/mlos_bench/mlos_bench/storage/sql/common.py +++ b/mlos_bench/mlos_bench/storage/sql/common.py @@ -244,5 +244,7 @@ def get_results_df( # Concat the trials, configs, and results. 
return trials_df.merge(configs_df, on=["trial_id", "tunable_config_id"], how="left").merge( - results_df, on="trial_id", how="left" + results_df, + on="trial_id", + how="left", ) diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 443c4b2c82..28e3d6f358 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -256,7 +256,10 @@ def pending_trials(self, timestamp: datetime, *, running: bool) -> Iterator[Stor ) for trial in cur_trials.fetchall(): tunables = self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id + conn, + self._schema.config_param, + "param", + config_id=trial.config_id, ) config = self._get_key_val( conn, diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index 00ab6ab9d1..ce314c17b3 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -144,5 +144,6 @@ def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER ) check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.UNIFIED, ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 8b29cfbd08..1f6906725e 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -75,7 +75,8 @@ def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER ) check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.UNIFIED, ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 0f7e3ef7f2..7e234eaf0b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -96,5 +96,6 @@ def test_service_configs_with_extra_param(test_case_name: str) -> None: TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE ) check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.UNIFIED + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.UNIFIED, ) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index 9d0d604c14..9bd6addab1 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -73,7 +73,8 @@ def test_storage_configs_with_extra_param(test_case_name: str) -> None: TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE ) check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], 
ConfigSchema.UNIFIED + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.UNIFIED, ) diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 5e9cb8ed13..0f23a749f8 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -35,7 +35,9 @@ def predicate(config_path: str) -> bool: configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, ) assert configs diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index 480b17425d..a4c85cb3da 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -28,7 +28,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, ) assert configs diff --git a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py index ecefc05cdd..e3a12bd3ed 100644 --- a/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py +++ b/mlos_bench/mlos_bench/tests/environments/remote/test_ssh_env.py @@ -39,7 +39,10 @@ def test_remote_ssh_env(ssh_test_server: SshTestServerInfo) -> None: ) config_path = service.resolve_path("environments/remote/test_ssh_env.jsonc") env = service.load_environment( - config_path, TunableGroups(), global_config=global_config, service=service + config_path, + TunableGroups(), + global_config=global_config, + service=service, ) check_env_success( diff --git a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py index 3b45d7dcd6..23aa56e48c 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mlos_core_opt_smac_test.py @@ -129,7 +129,9 @@ def test_init_mlos_core_smac_relative_output_directory_with_experiment_id( assert isinstance(opt._opt, SmacOptimizer) assert path_join(str(opt._opt.base_optimizer.scenario.output_directory)).startswith( path_join( - os.getcwd(), str(test_opt_config["output_directory"]), global_config["experiment_id"] + os.getcwd(), + str(test_opt_config["output_directory"]), + global_config["experiment_id"], ) ) shutil.rmtree(_OUTPUT_DIR) diff --git a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py index 3f8a6514ed..0be2ac7749 100644 --- a/mlos_bench/mlos_bench/tests/services/config_persistence_test.py +++ b/mlos_bench/mlos_bench/tests/services/config_persistence_test.py @@ -97,7 +97,8 @@ def test_load_config(config_persistence_service: ConfigPersistenceService) -> No `config_path`. 
""" tunables_data = config_persistence_service.load_config( - "tunable-values/tunable-values-example.jsonc", ConfigSchema.TUNABLE_VALUES + "tunable-values/tunable-values-example.jsonc", + ConfigSchema.TUNABLE_VALUES, ) assert tunables_data is not None assert isinstance(tunables_data, dict) diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py index c52e643025..e3890149bd 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_python_test.py @@ -54,7 +54,9 @@ def test_run_python_script(local_exec_service: LocalExecService) -> None: ) (return_code, _stdout, stderr) = local_exec_service.local_exec( - [f"{script_path} {input_file} {meta_file} {output_file}"], cwd=temp_dir, env=params + [f"{script_path} {input_file} {meta_file} {output_file}"], + cwd=temp_dir, + env=params, ) assert stderr.strip() == "" diff --git a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py index 6a64398fc3..7165496f9d 100644 --- a/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py +++ b/mlos_bench/mlos_bench/tests/services/local/local_exec_test.py @@ -212,5 +212,7 @@ def test_temp_dir_path_expansion() -> None: # pylint: disable=protected-access assert isinstance(local_exec_service._temp_dir, str) assert path_join(local_exec_service._temp_dir, abs_path=True) == path_join( - temp_dir, "temp", abs_path=True + temp_dir, + "temp", + abs_path=True, ) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index 54ceb9984e..fa935563d7 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -240,5 +240,8 @@ def test_ssh_service_reboot( locked_docker_services, reboot_test_server, ssh_host_service, graceful=True ) check_ssh_service_reboot( - locked_docker_services, reboot_test_server, ssh_host_service, graceful=False + locked_docker_services, + reboot_test_server, + ssh_host_service, + graceful=False, ) diff --git a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py index c792e82bcd..fcbca29ed9 100644 --- a/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py +++ b/mlos_bench/mlos_bench/tests/tunables/test_tunables_size_props.py @@ -80,7 +80,8 @@ def test_tunable_quantized_int_size_props() -> None: def test_tunable_quantized_float_size_props() -> None: """Test quantized tunable float size properties.""" tunable = Tunable( - name="test", config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1} + name="test", + config={"type": "float", "range": [0, 1], "default": 0, "quantization": 0.1}, ) assert tunable.span == 1 assert tunable.cardinality == 11 diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 9be5ea9f37..1a6d3a804b 100644 --- a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -614,7 +614,10 @@ def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]: return ( float(x) for x in np.linspace( - start=num_range[0], stop=num_range[1], num=cardinality, endpoint=True + start=num_range[0], + stop=num_range[1], + 
num=cardinality, + endpoint=True, ) ) assert self.type == "int", f"Unhandled tunable type: {self}" diff --git a/mlos_core/mlos_core/spaces/adapters/__init__.py b/mlos_core/mlos_core/spaces/adapters/__init__.py index 3187e32bc6..1645ac9cb4 100644 --- a/mlos_core/mlos_core/spaces/adapters/__init__.py +++ b/mlos_core/mlos_core/spaces/adapters/__init__.py @@ -78,7 +78,8 @@ def create( space_adapter_kwargs = {} space_adapter: ConcreteSpaceAdapter = space_adapter_type.value( - orig_parameter_space=parameter_space, **space_adapter_kwargs + orig_parameter_space=parameter_space, + **space_adapter_kwargs, ) return space_adapter diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index 68599e176b..65f0d9ab92 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -32,7 +32,9 @@ def test_context_not_implemented_warning( if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=["score"], + **kwargs, ) suggestion, _metadata = optimizer.suggest() scores = pd.DataFrame({"score": [1]}) diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index f557b05883..a69377b815 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -466,7 +466,9 @@ def test_llamatune_pipeline( # Define config space with a mix of different parameter types input_space = construct_parameter_space( - n_continuous_params=10, n_integer_params=10, n_categorical_params=5 + n_continuous_params=10, + n_integer_params=10, + n_categorical_params=5, ) adapter = LlamaTuneAdapter( orig_parameter_space=input_space, diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index 84e1fb3bd3..c5ed8f3a90 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -457,7 +457,7 @@ def plot_top_n_configs( ) (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column( - top_n_config_results_df + top_n_config_results_df, ) top_n = len(top_n_config_results_df[groupby_column].unique()) - 1 diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 7275966350..55d88c71e9 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -49,7 +49,10 @@ def ignore_plotter_warnings() -> None: "ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers" ) warnings.filterwarnings( - "ignore", module="dabl", category=UserWarning, message="Not plotting highly correlated" + "ignore", + module="dabl", + category=UserWarning, + message="Not plotting highly correlated", ) warnings.filterwarnings( "ignore", From 948ba2aac0f12d205d77bf9d1e7776a27032b5b3 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 20:43:53 +0000 Subject: [PATCH 46/54] reformat some single line args to multiline --- mlos_bench/mlos_bench/launcher.py | 12 +++++-- .../remote/azure/azure_deployment_services.py | 3 +- .../mlos_bench/storage/sql/experiment.py | 10 ++++-- mlos_bench/mlos_bench/tests/__init__.py | 3 +- .../test_load_environment_config_examples.py | 4 ++- .../optimizers/test_optimizer_schemas.py | 15 +++++--- .../schedulers/test_scheduler_schemas.py | 9 +++-- 
.../schemas/services/test_services_schemas.py | 3 +- .../schemas/storage/test_storage_schemas.py | 9 +++-- .../environments/include_tunables_test.py | 4 ++- .../optimizers/grid_search_optimizer_test.py | 6 ++-- .../remote/azure/azure_vm_services_test.py | 30 +++++++++++----- .../remote/ssh/test_ssh_host_service.py | 8 +++-- .../mlos_bench/tunables/covariant_group.py | 4 ++- .../mlos_bench/tunables/tunable_groups.py | 4 ++- .../bayesian_optimizers/smac_optimizer.py | 34 +++++++++++++------ .../mlos_core/spaces/adapters/llamatune.py | 26 +++++++++----- .../mlos_core/spaces/converters/flaml.py | 6 ++-- .../tests/optimizers/optimizer_test.py | 19 +++++++---- .../tests/spaces/adapters/llamatune_test.py | 32 +++++++++++------ mlos_viz/mlos_viz/base.py | 10 ++++-- mlos_viz/mlos_viz/dabl.py | 10 ++++-- 22 files changed, 181 insertions(+), 80 deletions(-) diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index f62170cc41..23421f195b 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -113,7 +113,9 @@ def __init__(self, description: str, long_text: str = "", argv: Optional[List[st service_files: List[str] = config.get("services", []) + (args.service or []) assert isinstance(self._parent_service, SupportsConfigLoading) self._parent_service = self._parent_service.load_services( - service_files, self.global_config, self._parent_service + service_files, + self.global_config, + self._parent_service, ) env_path = args.environment or config.get("environment") @@ -164,7 +166,8 @@ def service(self) -> Service: @staticmethod def _parse_args( - parser: argparse.ArgumentParser, argv: Optional[List[str]] + parser: argparse.ArgumentParser, + argv: Optional[List[str]], ) -> Tuple[argparse.Namespace, List[str]]: """Parse the command line arguments.""" parser.add_argument( @@ -376,7 +379,10 @@ def _load_config( return global_config def _init_tunable_values( - self, random_init: bool, seed: Optional[int], args_tunables: Optional[str] + self, + random_init: bool, + seed: Optional[int], + args_tunables: Optional[str], ) -> TunableGroups: """Initialize the tunables and load key/value pairs of the tunable values from given JSON files, if specified. diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py index b7d35e2792..9503d11409 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_deployment_services.py @@ -92,7 +92,8 @@ def __init__( if self.config.get("deploymentTemplatePath") is not None: # TODO: Provide external schema validation? 
template = self.config_loader_service.load_config( - self.config["deploymentTemplatePath"], schema_type=None + self.config["deploymentTemplatePath"], + schema_type=None, ) assert template is not None and isinstance(template, dict) self._deploy_template = template diff --git a/mlos_bench/mlos_bench/storage/sql/experiment.py b/mlos_bench/mlos_bench/storage/sql/experiment.py index 28e3d6f358..ffd26e6d39 100644 --- a/mlos_bench/mlos_bench/storage/sql/experiment.py +++ b/mlos_bench/mlos_bench/storage/sql/experiment.py @@ -182,7 +182,10 @@ def load( trial_ids.append(trial.trial_id) configs.append( self._get_key_val( - conn, self._schema.config_param, "param", config_id=trial.config_id + conn, + self._schema.config_param, + "param", + config_id=trial.config_id, ) ) if stat.is_succeeded(): @@ -223,7 +226,10 @@ def _get_key_val(conn: Connection, table: Table, field: str, **kwargs: Any) -> D @staticmethod def _save_params( - conn: Connection, table: Table, params: Dict[str, Any], **kwargs: Any + conn: Connection, + table: Table, + params: Dict[str, Any], + **kwargs: Any, ) -> None: if not params: return diff --git a/mlos_bench/mlos_bench/tests/__init__.py b/mlos_bench/mlos_bench/tests/__init__.py index 4fca4fc449..8737057665 100644 --- a/mlos_bench/mlos_bench/tests/__init__.py +++ b/mlos_bench/mlos_bench/tests/__init__.py @@ -51,7 +51,8 @@ debug("Docker is available but missing support for targeting linux platform.") DOCKER = None requires_docker = pytest.mark.skipif( - not DOCKER, reason="Docker with Linux support is not available on this system." + not DOCKER, + reason="Docker with Linux support is not available on this system.", ) # A decorator for tests that require ssh. diff --git a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py index c4a29d8984..fe5e651d95 100644 --- a/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/environments/test_load_environment_config_examples.py @@ -34,7 +34,9 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: configs = locate_config_examples( - ConfigPersistenceService.BUILTIN_CONFIG_PATH, CONFIG_TYPE, filter_configs + ConfigPersistenceService.BUILTIN_CONFIG_PATH, + CONFIG_TYPE, + filter_configs, ) assert configs diff --git a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py index ce314c17b3..87c7dd7a27 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/optimizers/test_optimizer_schemas.py @@ -34,7 +34,8 @@ expected_mlos_bench_optimizer_class_names = [ subclass.__module__ + "." 
+ subclass.__name__ for subclass in get_all_concrete_subclasses( - Optimizer, pkg_name="mlos_bench" # type: ignore[type-abstract] + Optimizer, # type: ignore[type-abstract] + pkg_name="mlos_bench", ) ] assert expected_mlos_bench_optimizer_class_names @@ -53,7 +54,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_optimizer_type", expected_mlos_bench_optimizer_class_names) def test_case_coverage_mlos_bench_optimizer_type( - test_case_subtype: str, mlos_bench_optimizer_type: str + test_case_subtype: str, + mlos_bench_optimizer_type: str, ) -> None: """Checks to see if there is a given type of test case for the given mlos_bench optimizer type. @@ -75,7 +77,8 @@ def test_case_coverage_mlos_bench_optimizer_type( # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_optimizer_type", expected_mlos_core_optimizer_types) def test_case_coverage_mlos_core_optimizer_type( - test_case_type: str, mlos_core_optimizer_type: Optional[OptimizerType] + test_case_type: str, + mlos_core_optimizer_type: Optional[OptimizerType], ) -> None: """Checks to see if there is a given type of test case for the given mlos_core optimizer type. @@ -101,7 +104,8 @@ def test_case_coverage_mlos_core_optimizer_type( # @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_core_space_adapter_type", expected_mlos_core_space_adapter_types) def test_case_coverage_mlos_core_space_adapter_type( - test_case_type: str, mlos_core_space_adapter_type: Optional[SpaceAdapterType] + test_case_type: str, + mlos_core_space_adapter_type: Optional[SpaceAdapterType], ) -> None: """Checks to see if there is a given type of test case for the given mlos_core space adapter type. @@ -141,7 +145,8 @@ def test_optimizer_configs_with_extra_param(test_case_name: str) -> None: certain places. """ check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.OPTIMIZER + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.OPTIMIZER, ) check_test_case_config_with_extra_param( TEST_CASES.by_type["good"][test_case_name], diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py index 1f6906725e..56945739d7 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test_scheduler_schemas.py @@ -31,7 +31,8 @@ expected_mlos_bench_scheduler_class_names = [ subclass.__module__ + "." + subclass.__name__ for subclass in get_all_concrete_subclasses( - Scheduler, pkg_name="mlos_bench" # type: ignore[type-abstract] + Scheduler, # type: ignore[type-abstract] + pkg_name="mlos_bench", ) ] assert expected_mlos_bench_scheduler_class_names @@ -42,7 +43,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_scheduler_type", expected_mlos_bench_scheduler_class_names) def test_case_coverage_mlos_bench_scheduler_type( - test_case_subtype: str, mlos_bench_scheduler_type: str + test_case_subtype: str, + mlos_bench_scheduler_type: str, ) -> None: """Checks to see if there is a given type of test case for the given mlos_bench scheduler type. @@ -72,7 +74,8 @@ def test_scheduler_configs_with_extra_param(test_case_name: str) -> None: certain places. 
""" check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SCHEDULER + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.SCHEDULER, ) check_test_case_config_with_extra_param( TEST_CASES.by_type["good"][test_case_name], diff --git a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py index 7e234eaf0b..e8b95ad85b 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/services/test_services_schemas.py @@ -93,7 +93,8 @@ def test_service_configs_with_extra_param(test_case_name: str) -> None: certain places. """ check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.SERVICE + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.SERVICE, ) check_test_case_config_with_extra_param( TEST_CASES.by_type["good"][test_case_name], diff --git a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py index 9bd6addab1..c3dd4ced81 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/storage/test_storage_schemas.py @@ -29,7 +29,8 @@ expected_mlos_bench_storage_class_names = [ subclass.__module__ + "." + subclass.__name__ for subclass in get_all_concrete_subclasses( - Storage, pkg_name="mlos_bench" # type: ignore[type-abstract] + Storage, # type: ignore[type-abstract] + pkg_name="mlos_bench", ) ] assert expected_mlos_bench_storage_class_names @@ -40,7 +41,8 @@ @pytest.mark.parametrize("test_case_subtype", sorted(TEST_CASES.by_subtype)) @pytest.mark.parametrize("mlos_bench_storage_type", expected_mlos_bench_storage_class_names) def test_case_coverage_mlos_bench_storage_type( - test_case_subtype: str, mlos_bench_storage_type: str + test_case_subtype: str, + mlos_bench_storage_type: str, ) -> None: """Checks to see if there is a given type of test case for the given mlos_bench storage type. @@ -70,7 +72,8 @@ def test_storage_configs_with_extra_param(test_case_name: str) -> None: certain places. 
""" check_test_case_config_with_extra_param( - TEST_CASES.by_type["good"][test_case_name], ConfigSchema.STORAGE + TEST_CASES.by_type["good"][test_case_name], + ConfigSchema.STORAGE, ) check_test_case_config_with_extra_param( TEST_CASES.by_type["good"][test_case_name], diff --git a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py index a3df4cb558..4c4fcd5dae 100644 --- a/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py +++ b/mlos_bench/mlos_bench/tests/environments/include_tunables_test.py @@ -12,7 +12,9 @@ def test_one_group(tunable_groups: TunableGroups) -> None: """Make sure only one tunable group is available to the environment.""" env = MockEnv( - name="Test Env", config={"tunable_params": ["provision"]}, tunables=tunable_groups + name="Test Env", + config={"tunable_params": ["provision"]}, + tunables=tunable_groups, ) assert env.tunable_params.get_param_values() == { "vmSize": "Standard_B4ms", diff --git a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py index 8761201c8e..769bf8859d 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/grid_search_optimizer_test.py @@ -75,7 +75,8 @@ def grid_search_tunables(grid_search_tunables_config: dict) -> TunableGroups: @pytest.fixture def grid_search_opt( - grid_search_tunables: TunableGroups, grid_search_tunables_grid: List[Dict[str, TunableValue]] + grid_search_tunables: TunableGroups, + grid_search_tunables_grid: List[Dict[str, TunableValue]], ) -> GridSearchOptimizer: """Test fixture for grid search optimizer.""" assert len(grid_search_tunables) == 3 @@ -280,7 +281,8 @@ def test_grid_search_async_order(grid_search_opt: GridSearchOptimizer) -> None: def test_grid_search_register( - grid_search_opt: GridSearchOptimizer, grid_search_tunables: TunableGroups + grid_search_opt: GridSearchOptimizer, + grid_search_tunables: TunableGroups, ) -> None: """Make sure that the `.register()` method adjusts the score signs correctly.""" assert grid_search_opt.register( diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py index 33f25f48c8..6418da01a9 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_vm_services_test.py @@ -37,10 +37,12 @@ def test_wait_host_deployment_retry( mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") + "Connection aborted", + OSError(107, "Transport endpoint is not connected"), ), requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") + "Connection aborted", + OSError(107, "Transport endpoint is not connected"), ), make_httplib_json_response(200, {"properties": {"provisioningState": "Running"}}), make_httplib_json_response(200, {"properties": {"provisioningState": "Succeeded"}}), @@ -166,7 +168,9 @@ def test_vm_operation_status( ], ) def test_vm_operation_invalid( - azure_vm_service_remote_exec_only: AzureVMService, operation_name: str, accepts_params: bool + azure_vm_service_remote_exec_only: AzureVMService, + 
operation_name: str, + accepts_params: bool, ) -> None: """Test VM operation status for an incomplete service config.""" operation = getattr(azure_vm_service_remote_exec_only, operation_name) @@ -177,7 +181,9 @@ def test_vm_operation_invalid( @patch("mlos_bench.services.remote.azure.azure_deployment_services.time.sleep") @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") def test_wait_vm_operation_ready( - mock_session: MagicMock, mock_sleep: MagicMock, azure_vm_service: AzureVMService + mock_session: MagicMock, + mock_sleep: MagicMock, + azure_vm_service: AzureVMService, ) -> None: """Test waiting for the completion of the remote VM operation.""" # Mock response header @@ -204,7 +210,8 @@ def test_wait_vm_operation_ready( @patch("mlos_bench.services.remote.azure.azure_deployment_services.requests.Session") def test_wait_vm_operation_timeout( - mock_session: MagicMock, azure_vm_service: AzureVMService + mock_session: MagicMock, + azure_vm_service: AzureVMService, ) -> None: """Test the time out of the remote VM operation.""" # Mock response header @@ -241,10 +248,12 @@ def test_wait_vm_operation_retry( mock_getconn.return_value.getresponse.side_effect = [ make_httplib_json_response(200, {"status": "InProgress"}), requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") + "Connection aborted", + OSError(107, "Transport endpoint is not connected"), ), requests_ex.ConnectionError( - "Connection aborted", OSError(107, "Transport endpoint is not connected") + "Connection aborted", + OSError(107, "Transport endpoint is not connected"), ), make_httplib_json_response(200, {"status": "InProgress"}), make_httplib_json_response(200, {"status": "Succeeded"}), @@ -290,7 +299,9 @@ def test_remote_exec_status( mock_requests.post.return_value = mock_response status, _ = azure_vm_service_remote_exec_only.remote_exec( - script, config={"vmName": "test-vm"}, env_params={} + script, + config={"vmName": "test-vm"}, + env_params={}, ) assert status == operation_status @@ -298,7 +309,8 @@ def test_remote_exec_status( @patch("mlos_bench.services.remote.azure.azure_vm_services.requests") def test_remote_exec_headers_output( - mock_requests: MagicMock, azure_vm_service_remote_exec_only: AzureVMService + mock_requests: MagicMock, + azure_vm_service_remote_exec_only: AzureVMService, ) -> None: """Check if HTTP headers from the remote execution on Azure are correct.""" async_url_key = "asyncResultsUrl" diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py index fa935563d7..003a8e6433 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_host_service.py @@ -167,7 +167,8 @@ def check_ssh_service_reboot( # Now try to restart the server. (status, reboot_results_info) = ssh_host_service.reboot( - params=reboot_test_srv_ssh_svc_conf, force=not graceful + params=reboot_test_srv_ssh_svc_conf, + force=not graceful, ) assert status.is_pending() @@ -237,7 +238,10 @@ def test_ssh_service_reboot( """Test the SshHostService reboot operation.""" # Grouped together to avoid parallel runner interactions. 
check_ssh_service_reboot( - locked_docker_services, reboot_test_server, ssh_host_service, graceful=True + locked_docker_services, + reboot_test_server, + ssh_host_service, + graceful=True, ) check_ssh_service_reboot( locked_docker_services, diff --git a/mlos_bench/mlos_bench/tunables/covariant_group.py b/mlos_bench/mlos_bench/tunables/covariant_group.py index 1468ce5545..b30c879d8f 100644 --- a/mlos_bench/mlos_bench/tunables/covariant_group.py +++ b/mlos_bench/mlos_bench/tunables/covariant_group.py @@ -236,7 +236,9 @@ def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: return self.get_tunable(tunable).value def __setitem__( - self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + self, + tunable: Union[str, Tunable], + tunable_value: Union[TunableValue, Tunable], ) -> TunableValue: value: TunableValue = ( tunable_value.value if isinstance(tunable_value, Tunable) else tunable_value diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index 684d15f120..b3e3698c61 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -157,7 +157,9 @@ def __getitem__(self, tunable: Union[str, Tunable]) -> TunableValue: return self._index[name][name] def __setitem__( - self, tunable: Union[str, Tunable], tunable_value: Union[TunableValue, Tunable] + self, + tunable: Union[str, Tunable], + tunable_value: Union[TunableValue, Tunable], ) -> TunableValue: """Update the current value of a single tunable parameter.""" # Use double index to make sure we set the is_updated flag of the group diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 7833ab31eb..611dc04044 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -130,9 +130,8 @@ def __init__( if output_directory is None: # pylint: disable=consider-using-with try: - self._temp_output_directory = TemporaryDirectory( - ignore_cleanup_errors=True - ) # Argument added in Python 3.10 + # Argument added in Python 3.10 + self._temp_output_directory = TemporaryDirectory(ignore_cleanup_errors=True) except TypeError: self._temp_output_directory = TemporaryDirectory() output_directory = self._temp_output_directory.name @@ -155,10 +154,12 @@ def __init__( n_workers=1, # Use a single thread for evaluating trials ) intensifier: AbstractIntensifier = Optimizer_Smac.get_intensifier( - scenario, max_config_calls=1 + scenario, + max_config_calls=1, ) config_selector: ConfigSelector = Optimizer_Smac.get_config_selector( - scenario, retrain_after=1 + scenario, + retrain_after=1, ) # TODO: When bulk registering prior configs to rewarm the optimizer, @@ -207,7 +208,8 @@ def __init__( # get_random_design static method when random_design is None. 
assert isinstance(n_random_probability, float) and n_random_probability >= 0 random_design = ProbabilityRandomDesign( - probability=n_random_probability, seed=scenario.seed + probability=n_random_probability, + seed=scenario.seed, ) self.base_optimizer = Optimizer_Smac( @@ -218,7 +220,8 @@ def __init__( random_design=random_design, config_selector=config_selector, multi_objective_algorithm=Optimizer_Smac.get_multi_objective_algorithm( - scenario, objective_weights=self._objective_weights + scenario, + objective_weights=self._objective_weights, ), overwrite=True, logging_level=False, # Use the existing logger @@ -309,7 +312,8 @@ def _register( # Retrieve previously generated TrialInfo (returned by .ask()) or create # new TrialInfo instance info: TrialInfo = self.trial_info_map.get( - config, TrialInfo(config=config, seed=self.base_optimizer.scenario.seed) + config, + TrialInfo(config=config, seed=self.base_optimizer.scenario.seed), ) value = TrialValue(cost=list(score.astype(float)), time=0.0, status=StatusType.SUCCESS) self.base_optimizer.tell(info, value, save=False) @@ -318,7 +322,9 @@ def _register( self.base_optimizer.optimizer.save() def _suggest( - self, *, context: Optional[pd.DataFrame] = None + self, + *, + context: Optional[pd.DataFrame] = None, ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. @@ -363,7 +369,10 @@ def register_pending( raise NotImplementedError() def surrogate_predict( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, ) -> npt.NDArray: # pylint: disable=import-outside-toplevel from smac.utils.configspace import convert_configurations_to_array @@ -392,7 +401,10 @@ def surrogate_predict( ) def acquisition_function( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, ) -> npt.NDArray: if context is not None: warn(f"Not Implemented: Ignoring context {list(context.columns)}", UserWarning) diff --git a/mlos_core/mlos_core/spaces/adapters/llamatune.py b/mlos_core/mlos_core/spaces/adapters/llamatune.py index e304c0dd50..38d973a27f 100644 --- a/mlos_core/mlos_core/spaces/adapters/llamatune.py +++ b/mlos_core/mlos_core/spaces/adapters/llamatune.py @@ -99,7 +99,8 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: target_configurations = [] for _, config in configurations.astype("O").iterrows(): configuration = ConfigSpace.Configuration( - self.orig_parameter_space, values=config.to_dict() + self.orig_parameter_space, + values=config.to_dict(), ) target_config = self._suggested_configs.get(configuration, None) @@ -135,7 +136,8 @@ def inverse_transform(self, configurations: pd.DataFrame) -> pd.DataFrame: vector = self._config_scaler.inverse_transform([config_vector])[0] target_config_vector = self._pinv_matrix.dot(vector) target_config = ConfigSpace.Configuration( - self.target_parameter_space, vector=target_config_vector + self.target_parameter_space, + vector=target_config_vector, ) target_configurations.append(target_config) @@ -153,7 +155,8 @@ def transform(self, configuration: pd.DataFrame) -> pd.DataFrame: target_values_dict = configuration.iloc[0].to_dict() target_configuration = ConfigSpace.Configuration( - self.target_parameter_space, values=target_values_dict + self.target_parameter_space, + values=target_values_dict, ) orig_values_dict = self._transform(target_values_dict) @@ -167,7 +170,9 @@ def 
transform(self, configuration: pd.DataFrame) -> pd.DataFrame: ) def _construct_low_dim_space( - self, num_low_dims: int, max_unique_values_per_param: Optional[int] + self, + num_low_dims: int, + max_unique_values_per_param: Optional[int], ) -> None: """ Constructs the low-dimensional parameter (potentially discretized) search space. @@ -197,7 +202,9 @@ def _construct_low_dim_space( # range, used by HeSBO projection. hyperparameters = [ ConfigSpace.UniformIntegerHyperparameter( - name=f"dim_{idx}", lower=1, upper=max_unique_values_per_param + name=f"dim_{idx}", + lower=1, + upper=max_unique_values_per_param, ) for idx in range(num_low_dims) ] @@ -213,9 +220,8 @@ def _construct_low_dim_space( # Construct low-dimensional parameter search space config_space = ConfigSpace.ConfigurationSpace(name=self.orig_parameter_space.name) - config_space.random = ( - self._random_state - ) # use same random state as in original parameter space + # use same random state as in original parameter space + config_space.random = self._random_state config_space.add_hyperparameters(hyperparameters) self._target_config_space = config_space @@ -278,7 +284,9 @@ def _transform(self, configuration: dict) -> dict: return original_config def _special_param_value_scaler( - self, param: ConfigSpace.UniformIntegerHyperparameter, input_value: float + self, + param: ConfigSpace.UniformIntegerHyperparameter, + input_value: float, ) -> float: """ Biases the special value(s) of this parameter, by shifting the normalized diff --git a/mlos_core/mlos_core/spaces/converters/flaml.py b/mlos_core/mlos_core/spaces/converters/flaml.py index 8b669f98e4..71370853e4 100644 --- a/mlos_core/mlos_core/spaces/converters/flaml.py +++ b/mlos_core/mlos_core/spaces/converters/flaml.py @@ -52,11 +52,13 @@ def _one_parameter_convert(parameter: "Hyperparameter") -> FlamlDomain: if isinstance(parameter, ConfigSpace.UniformFloatHyperparameter): # FIXME: upper isn't included in the range return flaml_numeric_type[(type(parameter), parameter.log)]( - parameter.lower, parameter.upper + parameter.lower, + parameter.upper, ) elif isinstance(parameter, ConfigSpace.UniformIntegerHyperparameter): return flaml_numeric_type[(type(parameter), parameter.log)]( - parameter.lower, parameter.upper + 1 + parameter.lower, + parameter.upper + 1, ) elif isinstance(parameter, ConfigSpace.CategoricalHyperparameter): if len(np.unique(parameter.probabilities)) > 1: diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index 7233918673..a6aa77087c 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -45,7 +45,9 @@ def test_create_optimizer_and_suggest( if kwargs is None: kwargs = {} optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=["score"], + **kwargs, ) assert optimizer is not None @@ -91,7 +93,9 @@ def objective(x: pd.Series) -> pd.DataFrame: # Emukit doesn't allow specifying a random state, so we set the global seed. 
np.random.seed(SEED) optimizer = optimizer_class( - parameter_space=configuration_space, optimization_targets=["score"], **kwargs + parameter_space=configuration_space, + optimization_targets=["score"], + **kwargs, ) with pytest.raises(ValueError, match="No observations"): @@ -294,9 +298,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # loop for llamatune-optimizer suggestion, metadata = llamatune_optimizer.suggest() _x, _y = suggestion["x"].iloc[0], suggestion["y"].iloc[0] - assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx( - 3.0, rel=1e-3 - ) # optimizer explores 1-dimensional space + # optimizer explores 1-dimensional space + assert _x == pytest.approx(_y, rel=1e-3) or _x + _y == pytest.approx(3.0, rel=1e-3) observation = objective(suggestion) llamatune_optimizer.register(configs=suggestion, scores=observation, metadata=metadata) @@ -343,7 +346,8 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: # Dynamically determine all of the optimizers we have implemented. # Note: these must be sorted. optimizer_subclasses: List[Type[BaseOptimizer]] = get_all_concrete_subclasses( - BaseOptimizer, pkg_name="mlos_core" # type: ignore[type-abstract] + BaseOptimizer, # type: ignore[type-abstract] + pkg_name="mlos_core", ) assert optimizer_subclasses @@ -366,7 +370,8 @@ def test_optimizer_type_defs(optimizer_class: Type[BaseOptimizer]) -> None: ], ) def test_mixed_numerics_type_input_space_types( - optimizer_type: Optional[OptimizerType], kwargs: Optional[dict] + optimizer_type: Optional[OptimizerType], + kwargs: Optional[dict], ) -> None: """Toy problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. diff --git a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py index a69377b815..9d73b6ba8c 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/llamatune_test.py @@ -64,7 +64,8 @@ def construct_parameter_space( ), ) def test_num_low_dims( - num_target_space_dims: int, param_space_kwargs: dict + num_target_space_dims: int, + param_space_kwargs: dict, ) -> None: # pylint: disable=too-many-locals """Tests LlamaTune's low-to-high space projection method.""" input_space = construct_parameter_space(**param_space_kwargs) @@ -100,7 +101,8 @@ def test_num_low_dims( # Sampled config and this should be the same target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + adapter.target_parameter_space, + values=target_config_df.iloc[0].to_dict(), ) assert target_config == sampled_config @@ -213,7 +215,8 @@ def gen_random_configs(adapter: LlamaTuneAdapter, num_configs: int) -> Iterator[ ) orig_config_df = adapter.transform(sampled_config_df) orig_config = CS.Configuration( - adapter.orig_parameter_space, values=orig_config_df.iloc[0].to_dict() + adapter.orig_parameter_space, + values=orig_config_df.iloc[0].to_dict(), ) yield orig_config @@ -320,7 +323,9 @@ def test_max_unique_values_per_param() -> None: """Tests LlamaTune's parameter values discretization implementation.""" # Define config space with a mix of different parameter types input_space = CS.ConfigurationSpace(seed=1234) - input_space.add_hyperparameter(CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5)) + input_space.add_hyperparameter( + CS.UniformFloatHyperparameter(name="cont_1", lower=0, upper=5), + ) input_space.add_hyperparameter( 
CS.UniformFloatHyperparameter(name="cont_2", lower=1, upper=100) ) @@ -380,7 +385,8 @@ def test_max_unique_values_per_param() -> None: ), ) def test_approx_inverse_mapping( - num_target_space_dims: int, param_space_kwargs: dict + num_target_space_dims: int, + param_space_kwargs: dict, ) -> None: # pylint: disable=too-many-locals """Tests LlamaTune's approximate high-to-low space projection method, using pseudo- inverse. @@ -421,7 +427,8 @@ def test_approx_inverse_mapping( target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + adapter.target_parameter_space, + values=target_config_df.iloc[0].to_dict(), ) adapter.target_parameter_space.check_configuration(target_config) @@ -434,7 +441,8 @@ def test_approx_inverse_mapping( target_config_df = adapter.inverse_transform(sampled_config_df) # Low-dim (i.e., target) config should be valid target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + adapter.target_parameter_space, + values=target_config_df.iloc[0].to_dict(), ) adapter.target_parameter_space.check_configuration(target_config) @@ -459,7 +467,9 @@ def test_approx_inverse_mapping( ), ) def test_llamatune_pipeline( - num_low_dims: int, special_param_values: dict, max_unique_values_per_param: int + num_low_dims: int, + special_param_values: dict, + max_unique_values_per_param: int, ) -> None: """Tests LlamaTune space adapter when all components are active.""" # pylint: disable=too-many-locals @@ -499,7 +509,8 @@ def test_llamatune_pipeline( target_config_df = adapter.inverse_transform(orig_config_df) # Sampled config and this should be the same target_config = CS.Configuration( - adapter.target_parameter_space, values=target_config_df.iloc[0].to_dict() + adapter.target_parameter_space, + values=target_config_df.iloc[0].to_dict(), ) assert target_config == config @@ -550,7 +561,8 @@ def test_llamatune_pipeline( ), ) def test_deterministic_behavior_for_same_seed( - num_target_space_dims: int, param_space_kwargs: dict + num_target_space_dims: int, + param_space_kwargs: dict, ) -> None: """Tests LlamaTune's space adapter deterministic behavior when given same seed in the input parameter space. diff --git a/mlos_viz/mlos_viz/base.py b/mlos_viz/mlos_viz/base.py index c5ed8f3a90..0c6d58cd7f 100644 --- a/mlos_viz/mlos_viz/base.py +++ b/mlos_viz/mlos_viz/base.py @@ -246,7 +246,9 @@ def limit_top_n_configs( # Prepare the orderby columns. (results_df, objs_cols) = expand_results_data_args( - exp_data, results_df=results_df, objectives=objectives + exp_data, + results_df=results_df, + objectives=objectives, ) assert isinstance(results_df, pandas.DataFrame) @@ -327,7 +329,8 @@ def limit_top_n_configs( # Place the default config at the top of the list. 
top_n_config_ids.insert(0, default_config_id) top_n_config_results_df = pandas.concat( - [default_config_results_df, top_n_config_results_df], axis=0 + [default_config_results_df, top_n_config_results_df], + axis=0, ) return (top_n_config_results_df, top_n_config_ids, orderby_cols) @@ -453,7 +456,8 @@ def plot_top_n_configs( if "objectives" not in top_n_config_args: top_n_config_args["objectives"] = objectives (top_n_config_results_df, _top_n_config_ids, orderby_cols) = limit_top_n_configs( - exp_data=exp_data, **top_n_config_args + exp_data=exp_data, + **top_n_config_args, ) (top_n_config_results_df, _groupby_columns, groupby_column) = _add_groupby_desc_column( diff --git a/mlos_viz/mlos_viz/dabl.py b/mlos_viz/mlos_viz/dabl.py index 55d88c71e9..3f8ac640ad 100644 --- a/mlos_viz/mlos_viz/dabl.py +++ b/mlos_viz/mlos_viz/dabl.py @@ -43,10 +43,16 @@ def ignore_plotter_warnings() -> None: # pylint: disable=import-outside-toplevel warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings( - "ignore", module="dabl", category=UserWarning, message="Could not infer format" + "ignore", + module="dabl", + category=UserWarning, + message="Could not infer format", ) warnings.filterwarnings( - "ignore", module="dabl", category=UserWarning, message="(Dropped|Discarding) .* outliers" + "ignore", + module="dabl", + category=UserWarning, + message="(Dropped|Discarding) .* outliers", ) warnings.filterwarnings( "ignore", From b17696b68a8761828a2c0ed0fa7590fb774678a9 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 20:59:14 +0000 Subject: [PATCH 47/54] CI fixups (#783) 1. Temporarily avoid recent asyncssh version that breaks the tests 2. Workaround pylance issue with pyproject.toml related changes and pip editable modules install format (#768) See Also: - https://github.com/microsoft/pylance-release/issues/3473 May also affect `mypy`: - https://github.com/python/mypy/issues/16988 - https://github.com/python/mypy/issues/12313 --- conda-envs/mlos-3.10.yml | 9 ++++++--- conda-envs/mlos-3.11.yml | 9 ++++++--- conda-envs/mlos-3.8.yml | 9 ++++++--- conda-envs/mlos-3.9.yml | 9 ++++++--- conda-envs/mlos-windows.yml | 9 ++++++--- conda-envs/mlos.yml | 9 ++++++--- mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py | 5 ----- mlos_bench/pyproject.toml | 2 +- mlos_bench/setup.py | 2 +- mlos_core/pyproject.toml | 2 +- mlos_viz/pyproject.toml | 2 +- 11 files changed, 40 insertions(+), 27 deletions(-) diff --git a/conda-envs/mlos-3.10.yml b/conda-envs/mlos-3.10.yml index 75bf64c5bf..b76d48e5b5 100644 --- a/conda-envs/mlos-3.10.yml +++ b/conda-envs/mlos-3.10.yml @@ -38,6 +38,9 @@ dependencies: - types-pygments - types-requests - types-setuptools - - "--editable ../mlos_core[full-tests]" - - "--editable ../mlos_bench[full-tests]" - - "--editable ../mlos_viz[full-tests]" + # Workaround a pylance issue in vscode that prevents it finding the latest + # method of pip installing editable modules. 
+ # https://github.com/microsoft/pylance-release/issues/3473 + - "--config-settings editable_mode=compat --editable ../mlos_core[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_bench[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_viz[full-tests]" diff --git a/conda-envs/mlos-3.11.yml b/conda-envs/mlos-3.11.yml index 6443c7a308..64ee6fd58a 100644 --- a/conda-envs/mlos-3.11.yml +++ b/conda-envs/mlos-3.11.yml @@ -38,6 +38,9 @@ dependencies: - types-pygments - types-requests - types-setuptools - - "--editable ../mlos_core[full-tests]" - - "--editable ../mlos_bench[full-tests]" - - "--editable ../mlos_viz[full-tests]" + # Workaround a pylance issue in vscode that prevents it finding the latest + # method of pip installing editable modules. + # https://github.com/microsoft/pylance-release/issues/3473 + - "--config-settings editable_mode=compat --editable ../mlos_core[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_bench[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_viz[full-tests]" diff --git a/conda-envs/mlos-3.8.yml b/conda-envs/mlos-3.8.yml index 8b79aad2c4..b1e14c7402 100644 --- a/conda-envs/mlos-3.8.yml +++ b/conda-envs/mlos-3.8.yml @@ -38,6 +38,9 @@ dependencies: - types-pygments - types-requests - types-setuptools - - "--editable ../mlos_core[full-tests]" - - "--editable ../mlos_bench[full-tests]" - - "--editable ../mlos_viz[full-tests]" + # Workaround a pylance issue in vscode that prevents it finding the latest + # method of pip installing editable modules. + # https://github.com/microsoft/pylance-release/issues/3473 + - "--config-settings editable_mode=compat --editable ../mlos_core[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_bench[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_viz[full-tests]" diff --git a/conda-envs/mlos-3.9.yml b/conda-envs/mlos-3.9.yml index 88b384a428..edccdab405 100644 --- a/conda-envs/mlos-3.9.yml +++ b/conda-envs/mlos-3.9.yml @@ -38,6 +38,9 @@ dependencies: - types-pygments - types-requests - types-setuptools - - "--editable ../mlos_core[full-tests]" - - "--editable ../mlos_bench[full-tests]" - - "--editable ../mlos_viz[full-tests]" + # Workaround a pylance issue in vscode that prevents it finding the latest + # method of pip installing editable modules. + # https://github.com/microsoft/pylance-release/issues/3473 + - "--config-settings editable_mode=compat --editable ../mlos_core[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_bench[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_viz[full-tests]" diff --git a/conda-envs/mlos-windows.yml b/conda-envs/mlos-windows.yml index 1287247641..107b6fb2cf 100644 --- a/conda-envs/mlos-windows.yml +++ b/conda-envs/mlos-windows.yml @@ -42,6 +42,9 @@ dependencies: - types-requests - types-setuptools - pyarrow - - "--editable ../mlos_core[full-tests]" - - "--editable ../mlos_bench[full-tests]" - - "--editable ../mlos_viz[full-tests]" + # Workaround a pylance issue in vscode that prevents it finding the latest + # method of pip installing editable modules. 
+ # https://github.com/microsoft/pylance-release/issues/3473 + - "--config-settings editable_mode=compat --editable ../mlos_core[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_bench[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_viz[full-tests]" diff --git a/conda-envs/mlos.yml b/conda-envs/mlos.yml index a65633fcfe..5cd35fdbba 100644 --- a/conda-envs/mlos.yml +++ b/conda-envs/mlos.yml @@ -37,6 +37,9 @@ dependencies: - types-pygments - types-requests - types-setuptools - - "--editable ../mlos_core[full-tests]" - - "--editable ../mlos_bench[full-tests]" - - "--editable ../mlos_viz[full-tests]" + # Workaround a pylance issue in vscode that prevents it finding the latest + # method of pip installing editable modules. + # https://github.com/microsoft/pylance-release/issues/3473 + - "--config-settings editable_mode=compat --editable ../mlos_core[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_bench[full-tests]" + - "--config-settings editable_mode=compat --editable ../mlos_viz[full-tests]" diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index 80a9d4629e..e46e932f85 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -16,12 +16,7 @@ class OneShotOptimizer(MockOptimizer): """ -<<<<<<< HEAD - Mock optimizer that proposes a single configuration and returns. - -======= No-op optimizer that proposes a single configuration and returns. ->>>>>>> main Explicit configs (partial or full) are possible using configuration files. """ diff --git a/mlos_bench/pyproject.toml b/mlos_bench/pyproject.toml index 321d946f3c..911433da2d 100644 --- a/mlos_bench/pyproject.toml +++ b/mlos_bench/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "setuptools-scm>=8.1.0", "wheel"] +requires = ["setuptools>64", "setuptools-scm>=8.1.0", "wheel"] build-backend = "setuptools.build_meta" [project] diff --git a/mlos_bench/setup.py b/mlos_bench/setup.py index f86b7a9663..bac41680de 100644 --- a/mlos_bench/setup.py +++ b/mlos_bench/setup.py @@ -68,7 +68,7 @@ def _get_long_desc_from_readme(base_url: str) -> dict: extra_requires: Dict[str, List[str]] = { # pylint: disable=consider-using-namedtuple-or-dataclass # Additional tools for extra functionality. 
"azure": ["azure-storage-file-share", "azure-identity", "azure-keyvault"], - "ssh": ["asyncssh"], + "ssh": ["asyncssh<2.15.0"], # FIXME: asyncssh 2.15.0 has a bug that breaks the tests "storage-sql-duckdb": ["sqlalchemy", "duckdb_engine"], "storage-sql-mysql": ["sqlalchemy", "mysql-connector-python"], "storage-sql-postgres": ["sqlalchemy", "psycopg2"], diff --git a/mlos_core/pyproject.toml b/mlos_core/pyproject.toml index 49e7a5ff23..b55955133f 100644 --- a/mlos_core/pyproject.toml +++ b/mlos_core/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "setuptools-scm>=8.1.0", "wheel"] +requires = ["setuptools>64", "setuptools-scm>=8.1.0", "wheel"] build-backend = "setuptools.build_meta" [project] diff --git a/mlos_viz/pyproject.toml b/mlos_viz/pyproject.toml index c0101fbbee..b469e60ecc 100644 --- a/mlos_viz/pyproject.toml +++ b/mlos_viz/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "setuptools-scm>=8.1.0", "wheel"] +requires = ["setuptools>64", "setuptools-scm>=8.1.0", "wheel"] build-backend = "setuptools.build_meta" [project] From 2670ab7e63dbdec60d91ef8d14cafe7c4b35972d Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 21:02:19 +0000 Subject: [PATCH 48/54] formats --- mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py | 1 + mlos_bench/mlos_bench/services/config_persistence.py | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py index e46e932f85..f41114c185 100644 --- a/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py +++ b/mlos_bench/mlos_bench/optimizers/one_shot_optimizer.py @@ -17,6 +17,7 @@ class OneShotOptimizer(MockOptimizer): """ No-op optimizer that proposes a single configuration and returns. + Explicit configs (partial or full) are possible using configuration files. """ diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index ee1b7f7902..8e8a05b0e8 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -333,9 +333,7 @@ def build_storage( """ (class_name, class_config) = self.prepare_class_load(config, global_config) # pylint: disable=import-outside-toplevel - from mlos_bench.storage.base_storage import ( - Storage, - ) + from mlos_bench.storage.base_storage import Storage inst = instantiate_from_config( Storage, # type: ignore[type-abstract] From d892b58fd4fa44a159bea045653b00d622da7d7f Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 21:20:35 +0000 Subject: [PATCH 49/54] tweak --- mlos_bench/mlos_bench/tests/config/schemas/__init__.py | 6 ++++-- .../config/services/test_load_service_config_examples.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py index bd6921d8c2..02cdc4fdee 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/__init__.py +++ b/mlos_bench/mlos_bench/tests/config/schemas/__init__.py @@ -123,7 +123,8 @@ def get_schema_test_cases(test_cases_root: str) -> TestCases: def check_test_case_against_schema( - test_case: SchemaTestCaseInfo, schema_type: ConfigSchema + test_case: SchemaTestCaseInfo, + schema_type: ConfigSchema, ) -> None: """ Checks the given test case against the given schema. 
@@ -150,7 +151,8 @@ def check_test_case_against_schema( def check_test_case_config_with_extra_param( - test_case: SchemaTestCaseInfo, schema_type: ConfigSchema + test_case: SchemaTestCaseInfo, + schema_type: ConfigSchema, ) -> None: """Checks that the config fails to validate if extra params are present in certain places. diff --git a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py index 0f23a749f8..5545327080 100644 --- a/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/services/test_load_service_config_examples.py @@ -44,7 +44,8 @@ def predicate(config_path: str) -> bool: @pytest.mark.parametrize("config_path", configs) def test_load_service_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.SERVICE) From 77e263517b594f0013f748e1e4ac28be9477180d Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 21:22:24 +0000 Subject: [PATCH 50/54] reformat for multiline args --- mlos_bench/mlos_bench/dict_templater.py | 5 ++++- .../mlos_bench/services/remote/azure/azure_fileshare.py | 6 +++++- .../mlos_bench/services/remote/azure/azure_vm_services.py | 5 ++++- mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py | 3 ++- mlos_bench/mlos_bench/services/types/config_loader_type.py | 4 +++- mlos_bench/mlos_bench/storage/base_storage.py | 5 ++++- mlos_bench/mlos_bench/storage/sql/trial.py | 5 ++++- .../config/storage/test_load_storage_config_examples.py | 3 ++- mlos_bench/mlos_bench/tests/conftest.py | 3 ++- .../mlos_bench/tests/environments/composite_env_test.py | 3 ++- .../tests/environments/local/local_env_telemetry_test.py | 3 ++- .../tests/environments/local/local_fileshare_env_test.py | 3 ++- mlos_bench/mlos_bench/tests/environments/mock_env_test.py | 5 ++++- mlos_bench/mlos_bench/tests/launcher_run_test.py | 3 ++- mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py | 3 ++- .../tests/optimizers/toy_optimization_loop_test.py | 3 ++- .../tests/services/remote/azure/azure_fileshare_test.py | 4 +++- .../mlos_bench/tests/services/remote/azure/conftest.py | 3 ++- .../tests/services/remote/mock/mock_fileshare_service.py | 6 +++++- mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py | 3 ++- .../tests/services/remote/ssh/test_ssh_fileshare.py | 3 ++- mlos_bench/mlos_bench/tests/storage/exp_data_test.py | 4 +++- mlos_bench/mlos_bench/tests/storage/exp_load_test.py | 4 +++- mlos_bench/mlos_bench/tests/storage/sql/fixtures.py | 3 ++- mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py | 2 +- .../mlos_bench/tests/storage/tunable_config_data_test.py | 3 ++- .../tests/storage/tunable_config_trial_group_data_test.py | 3 ++- .../tests/tunables/tunable_group_indexing_test.py | 3 ++- .../tests/tunables/tunable_to_configspace_distr_test.py | 4 +++- .../tests/tunables/tunable_to_configspace_test.py | 3 ++- mlos_bench/mlos_bench/util.py | 4 +++- .../optimizers/bayesian_optimizers/bayesian_optimizer.py | 5 ++++- mlos_core/mlos_core/optimizers/flaml_optimizer.py | 4 +++- mlos_core/mlos_core/optimizers/optimizer.py | 4 +++- mlos_core/mlos_core/optimizers/random_optimizer.py | 4 +++- mlos_core/mlos_core/tests/optimizers/one_hot_test.py | 3 ++- 
.../mlos_core/tests/optimizers/optimizer_multiobj_test.py | 4 +++- .../tests/spaces/adapters/space_adapter_factory_test.py | 3 ++- mlos_core/mlos_core/util.py | 3 ++- 39 files changed, 105 insertions(+), 39 deletions(-) diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index 576bc175c3..ab49e2baf8 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -52,7 +52,10 @@ def expand_vars( return self._dict def _expand_vars( - self, value: Any, extra_source_dict: Optional[Dict[str, Any]], use_os_env: bool + self, + value: Any, + extra_source_dict: Optional[Dict[str, Any]], + use_os_env: bool, ) -> Any: """Recursively expand $var strings in the currently operating dictionary.""" if isinstance(value, str): diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py index 0f09694489..d80ea862c9 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_fileshare.py @@ -71,7 +71,11 @@ def __init__( ) def download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True + self, + params: dict, + remote_path: str, + local_path: str, + recursive: bool = True, ) -> None: super().download(params, remote_path, local_path, recursive) dir_client = self._share_client.get_directory_client(remote_path) diff --git a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py index 09d5986ef8..3d390645f5 100644 --- a/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py +++ b/mlos_bench/mlos_bench/services/remote/azure/azure_vm_services.py @@ -447,7 +447,10 @@ def reboot(self, params: dict, force: bool = False) -> Tuple["Status", dict]: return self.restart_host(params, force) def remote_exec( - self, script: Iterable[str], config: dict, env_params: dict + self, + script: Iterable[str], + config: dict, + env_params: dict, ) -> Tuple[Status, dict]: """ Run a command on Azure VM. diff --git a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py index 89888dec47..706764a1f1 100644 --- a/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py +++ b/mlos_bench/mlos_bench/services/remote/ssh/ssh_service.py @@ -163,7 +163,8 @@ def exit(self) -> None: self._cache_lock.release() async def get_client_connection( - self, connect_params: dict + self, + connect_params: dict, ) -> Tuple[SSHClientConnection, SshClient]: """ Gets a (possibly cached) client connection. diff --git a/mlos_bench/mlos_bench/services/types/config_loader_type.py b/mlos_bench/mlos_bench/services/types/config_loader_type.py index e29e5688ec..4eb473edff 100644 --- a/mlos_bench/mlos_bench/services/types/config_loader_type.py +++ b/mlos_bench/mlos_bench/services/types/config_loader_type.py @@ -49,7 +49,9 @@ def resolve_path(self, file_path: str, extra_paths: Optional[Iterable[str]] = No """ def load_config( - self, json_file_name: str, schema_type: Optional[ConfigSchema] + self, + json_file_name: str, + schema_type: Optional[ConfigSchema], ) -> Union[dict, List[dict]]: """ Load JSON config file. 
Search for a file relative to `_config_path` if the input diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index cd529c730c..9c0e88e3d5 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -400,7 +400,10 @@ def config(self, global_config: Optional[Dict[str, Any]] = None) -> Dict[str, An @abstractmethod def update( - self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None + self, + status: Status, + timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, Any]]: """ Update the storage with the results of the experiment. diff --git a/mlos_bench/mlos_bench/storage/sql/trial.py b/mlos_bench/mlos_bench/storage/sql/trial.py index 006f7761b8..6c2cf26cc7 100644 --- a/mlos_bench/mlos_bench/storage/sql/trial.py +++ b/mlos_bench/mlos_bench/storage/sql/trial.py @@ -47,7 +47,10 @@ def __init__( self._schema = schema def update( - self, status: Status, timestamp: datetime, metrics: Optional[Dict[str, Any]] = None + self, + status: Status, + timestamp: datetime, + metrics: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, Any]]: # Make sure to convert the timestamp to UTC before storing it in the database. timestamp = utcify_timestamp(timestamp, origin="local") diff --git a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py index a4c85cb3da..bb9161144a 100644 --- a/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py +++ b/mlos_bench/mlos_bench/tests/config/storage/test_load_storage_config_examples.py @@ -37,7 +37,8 @@ def filter_configs(configs_to_filter: List[str]) -> List[str]: @pytest.mark.parametrize("config_path", configs) def test_load_storage_config_examples( - config_loader_service: ConfigPersistenceService, config_path: str + config_loader_service: ConfigPersistenceService, + config_path: str, ) -> None: """Tests loading a config example.""" config = config_loader_service.load_config(config_path, ConfigSchema.STORAGE) diff --git a/mlos_bench/mlos_bench/tests/conftest.py b/mlos_bench/mlos_bench/tests/conftest.py index 09b242ac12..a13c57a2cd 100644 --- a/mlos_bench/mlos_bench/tests/conftest.py +++ b/mlos_bench/mlos_bench/tests/conftest.py @@ -98,7 +98,8 @@ def docker_compose_project_name(short_testrun_uid: str) -> str: @pytest.fixture(scope="session") def docker_services_lock( - shared_temp_dir: str, short_testrun_uid: str + shared_temp_dir: str, + short_testrun_uid: str, ) -> InterProcessReaderWriterLock: """ Gets a pytest session lock for xdist workers to mark when they're using the docker diff --git a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py index 77a6bf5ad4..80463ea3d9 100644 --- a/mlos_bench/mlos_bench/tests/environments/composite_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/composite_env_test.py @@ -256,7 +256,8 @@ def test_nested_composite_env_params(nested_composite_env: CompositeEnv) -> None def test_nested_composite_env_setup( - nested_composite_env: CompositeEnv, tunable_groups: TunableGroups + nested_composite_env: CompositeEnv, + tunable_groups: TunableGroups, ) -> None: """Check that the child environments update their tunable parameters.""" tunable_groups.assign( diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py 
b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index 7f3b070109..4dc9ae43dc 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -118,7 +118,8 @@ def test_local_env_telemetry_no_header( ) # pylint: disable=line-too-long # noqa @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_local_env_telemetry_wrong_header( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """Read the telemetry data with incorrect header.""" ts1 = datetime.now(zone_info) diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py index 08ce0790bc..9c40d422e7 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_fileshare_env_test.py @@ -27,7 +27,8 @@ def mock_fileshare_service() -> MockFileShareService: @pytest.fixture def local_fileshare_env( - tunable_groups: TunableGroups, mock_fileshare_service: MockFileShareService + tunable_groups: TunableGroups, + mock_fileshare_service: MockFileShareService, ) -> LocalFileShareEnv: """Create a LocalFileShareEnv instance.""" env = LocalFileShareEnv( diff --git a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py index b29f1098d7..3a82d8dfd3 100644 --- a/mlos_bench/mlos_bench/tests/environments/mock_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/mock_env_test.py @@ -50,7 +50,10 @@ def test_mock_env_no_noise(mock_env_no_noise: MockEnv, tunable_groups: TunableGr ], ) def test_mock_env_assign( - mock_env: MockEnv, tunable_groups: TunableGroups, tunable_values: dict, expected_score: float + mock_env: MockEnv, + tunable_groups: TunableGroups, + tunable_values: dict, + expected_score: float, ) -> None: """Check the benchmark values of the mock environment after the assignment.""" with mock_env as env_context: diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index d6f5b8cfd5..bb16190bfb 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -91,7 +91,8 @@ def test_launch_main_app_bench(root_path: str, local_exec_service: LocalExecServ def test_launch_main_app_bench_values( - root_path: str, local_exec_service: LocalExecService + root_path: str, + local_exec_service: LocalExecService, ) -> None: """Run mlos_bench command-line application with mock benchmark config and user- specified tunable values and check the results in the log. 
diff --git a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py index ee41f95b13..05305de50b 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/mock_opt_test.py @@ -83,7 +83,8 @@ def test_mock_optimizer(mock_opt: MockOptimizer, mock_configurations: list) -> N def test_mock_optimizer_no_defaults( - mock_opt_no_defaults: MockOptimizer, mock_configurations_no_defaults: list + mock_opt_no_defaults: MockOptimizer, + mock_configurations_no_defaults: list, ) -> None: """Make sure that mock optimizer produces consistent suggestions.""" score = _optimize(mock_opt_no_defaults, mock_configurations_no_defaults) diff --git a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py index 1596d4997d..db46189e44 100644 --- a/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py +++ b/mlos_bench/mlos_bench/tests/optimizers/toy_optimization_loop_test.py @@ -78,7 +78,8 @@ def test_mock_optimization_loop(mock_env_no_noise: MockEnv, mock_opt: MockOptimi def test_mock_optimization_loop_no_defaults( - mock_env_no_noise: MockEnv, mock_opt_no_defaults: MockOptimizer + mock_env_no_noise: MockEnv, + mock_opt_no_defaults: MockOptimizer, ) -> None: """Toy optimization loop with mock environment and optimizer.""" (score, tunables) = _optimize(mock_env_no_noise, mock_opt_no_defaults) diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index fa1adc9935..9b24e987db 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -140,7 +140,9 @@ def test_download_folder_recursive( @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.path.isdir") def test_upload_file( - mock_isdir: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService + mock_isdir: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, ) -> None: filename = "test.csv" remote_folder = "a/remote/folder" diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py index 96cdc9f1d1..ad7bae26ee 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/conftest.py @@ -27,7 +27,8 @@ def config_persistence_service() -> ConfigPersistenceService: @pytest.fixture def azure_auth_service( - config_persistence_service: ConfigPersistenceService, monkeypatch: pytest.MonkeyPatch + config_persistence_service: ConfigPersistenceService, + monkeypatch: pytest.MonkeyPatch, ) -> AzureAuthService: """Creates a dummy AzureAuthService for tests that require it.""" auth = AzureAuthService(config={}, global_config={}, parent=config_persistence_service) diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index abeb35f091..7cbe47a07c 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -39,7 +39,11 @@ def upload( self._upload.append((local_path, remote_path)) def 
download( - self, params: dict, remote_path: str, local_path: str, recursive: bool = True + self, + params: dict, + remote_path: str, + local_path: str, + recursive: bool = True, ) -> None: self._download.append((remote_path, local_path)) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index 913b045a76..e9bd43124b 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -122,7 +122,8 @@ def alt_test_server( @pytest.fixture(scope="session") def reboot_test_server( - ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices + ssh_test_server: SshTestServerInfo, + locked_docker_services: DockerServices, ) -> SshTestServerInfo: """ Fixture for getting the third ssh test server info from the docker-compose.yml. diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index a6a7c6149b..a319f2e5bf 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -150,7 +150,8 @@ def test_ssh_fileshare_recursive( @requires_docker def test_ssh_fileshare_download_file_dne( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService + ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService, ) -> None: """Test the SshFileShareService single file download that doesn't exist.""" with ssh_fileshare_service: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index e6ef30db6a..9931df7ef7 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -31,7 +31,9 @@ def test_exp_data_root_env_config( def test_exp_trial_data_objectives( - storage: Storage, exp_storage: Storage.Experiment, tunable_groups: TunableGroups + storage: Storage, + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, ) -> None: """Start a new trial and check the storage for the trial data.""" diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index 0cbd02ae97..259d880f92 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -118,7 +118,9 @@ def test_exp_trial_update_twice( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_exp_trial_pending_3( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """ Start THREE trials, let one succeed, another one fail and keep one not updated. 
diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index fa26245e78..ec8515b412 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -216,7 +216,8 @@ def exp_no_tunables_data( @pytest.fixture def mixed_numerics_exp_data( - storage: SqlStorage, mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment + storage: SqlStorage, + mixed_numerics_exp_storage_with_trials: SqlStorage.Experiment, ) -> ExperimentData: """Test fixture for ExperimentData with mixed numerical tunable types.""" return storage.experiments[mixed_numerics_exp_storage_with_trials.experiment_id] diff --git a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py index aeb3d9fbee..72f73724db 100644 --- a/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/storage/trial_telemetry_test.py @@ -42,7 +42,7 @@ def zoned_telemetry_data(zone_info: Optional[tzinfo]) -> List[Tuple[datetime, st def _telemetry_str( - data: List[Tuple[datetime, str, Any]] + data: List[Tuple[datetime, str, Any]], ) -> List[Tuple[datetime, str, Optional[str]]]: """Convert telemetry values to strings.""" # All retrieved timestamps should have been converted to UTC. diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index 20ed746462..4173ae2761 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -41,7 +41,8 @@ def test_trial_data_no_tunables_config_data(exp_no_tunables_data: ExperimentData def test_mixed_numerics_exp_trial_data( - mixed_numerics_exp_data: ExperimentData, mixed_numerics_tunable_groups: TunableGroups + mixed_numerics_exp_data: ExperimentData, + mixed_numerics_tunable_groups: TunableGroups, ) -> None: """Tests that data type conversions are retained when loading experiment data with mixed numeric tunable types. diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py index b8d83d5c32..faa61e5286 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_trial_group_data_test.py @@ -53,7 +53,8 @@ def test_exp_trial_data_tunable_config_trial_group_id(exp_data: ExperimentData) def test_tunable_config_trial_group_results_df( - exp_data: ExperimentData, tunable_groups: TunableGroups + exp_data: ExperimentData, + tunable_groups: TunableGroups, ) -> None: """Tests the results_df property of the TunableConfigTrialGroup.""" tunable_config_id = 2 diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py index 6e4d9c3658..ae22094baa 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_group_indexing_test.py @@ -9,7 +9,8 @@ def test_tunable_group_indexing( - tunable_groups: TunableGroups, tunable_categorical: Tunable + tunable_groups: TunableGroups, + tunable_categorical: Tunable, ) -> None: """Check that various types of indexing work for the tunable group.""" # Check that the "in" operator works. 
diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py index 97b8ea8c41..91e387f92b 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_distr_test.py @@ -44,7 +44,9 @@ ], ) def test_convert_numerical_distributions( - param_type: str, distr_name: DistributionName, distr_params: dict + param_type: str, + distr_name: DistributionName, + distr_params: dict, ) -> None: """Convert a numerical Tunable with explicit distribution to ConfigSpace.""" tunable_name = "x" diff --git a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py index 5350e5e4eb..dce3e366a6 100644 --- a/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py +++ b/mlos_bench/mlos_bench/tests/tunables/tunable_to_configspace_test.py @@ -136,7 +136,8 @@ def test_tunable_groups_to_hyperparameters(tunable_groups: TunableGroups) -> Non def test_tunable_groups_to_configspace( - tunable_groups: TunableGroups, configuration_space: ConfigurationSpace + tunable_groups: TunableGroups, + configuration_space: ConfigurationSpace, ) -> None: """Check the conversion of the entire TunableGroups collection to a single ConfigurationSpace object. diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index 37170b06c0..dc86901651 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -373,7 +373,9 @@ def utcify_nullable_timestamp( def datetime_parser( - datetime_col: pandas.Series, *, origin: Literal["utc", "local"] + datetime_col: pandas.Series, + *, + origin: Literal["utc", "local"], ) -> pandas.Series: """ Attempt to convert a pandas column to a datetime format. diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 11669d4d79..2f7f71672b 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -37,7 +37,10 @@ def surrogate_predict( @abstractmethod def acquisition_function( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, ) -> npt.NDArray: """ Invokes the acquisition function from this Bayesian optimizer for the given diff --git a/mlos_core/mlos_core/optimizers/flaml_optimizer.py b/mlos_core/mlos_core/optimizers/flaml_optimizer.py index 958f98e02e..50def8bc80 100644 --- a/mlos_core/mlos_core/optimizers/flaml_optimizer.py +++ b/mlos_core/mlos_core/optimizers/flaml_optimizer.py @@ -135,7 +135,9 @@ def _register( ) def _suggest( - self, *, context: Optional[pd.DataFrame] = None + self, + *, + context: Optional[pd.DataFrame] = None, ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. 
diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index ddd4a466db..70a2010a45 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -252,7 +252,9 @@ def get_observations(self) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.Data return (configs, scores, contexts if len(contexts.columns) > 0 else None) def get_best_observations( - self, *, n_max: int = 1 + self, + *, + n_max: int = 1, ) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]: """ Get the N best observations so far as a triplet of DataFrames (config, score, diff --git a/mlos_core/mlos_core/optimizers/random_optimizer.py b/mlos_core/mlos_core/optimizers/random_optimizer.py index ddee68f345..661a48a373 100644 --- a/mlos_core/mlos_core/optimizers/random_optimizer.py +++ b/mlos_core/mlos_core/optimizers/random_optimizer.py @@ -58,7 +58,9 @@ def _register( # should we pop them from self.pending_observations? def _suggest( - self, *, context: Optional[pd.DataFrame] = None + self, + *, + context: Optional[pd.DataFrame] = None, ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index da5a3d492a..0325c764d2 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -137,7 +137,8 @@ def test_round_trip_series(optimizer: BaseOptimizer, series: pd.DataFrame) -> No def test_round_trip_reverse_data_frame( - optimizer: BaseOptimizer, one_hot_data_frame: npt.NDArray + optimizer: BaseOptimizer, + one_hot_data_frame: npt.NDArray, ) -> None: """Round-trip test for one-hot-decoding and then encoding of a numpy array.""" round_trip = optimizer._to_1hot(config=optimizer._from_1hot(config=one_hot_data_frame)) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index c1f743dd03..6ca009316f 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -54,7 +54,9 @@ def test_multi_target_opt_wrong_weights( ], ) def test_multi_target_opt( - objective_weights: Optional[List[float]], optimizer_class: Type[BaseOptimizer], kwargs: dict + objective_weights: Optional[List[float]], + optimizer_class: Type[BaseOptimizer], + kwargs: dict, ) -> None: """Toy multi-target optimization problem to test the optimizers with mixed numeric types to ensure that original dtypes are retained. diff --git a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py index 188a0300e7..6dc35441dc 100644 --- a/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py +++ b/mlos_core/mlos_core/tests/spaces/adapters/space_adapter_factory_test.py @@ -48,7 +48,8 @@ def test_concrete_optimizer_type(space_adapter_type: SpaceAdapterType) -> None: ], ) def test_create_space_adapter_with_factory_method( - space_adapter_type: Optional[SpaceAdapterType], kwargs: Optional[dict] + space_adapter_type: Optional[SpaceAdapterType], + kwargs: Optional[dict], ) -> None: # Start defining a ConfigurationSpace for the Optimizer to search. 
input_space = CS.ConfigurationSpace(seed=1234) diff --git a/mlos_core/mlos_core/util.py b/mlos_core/mlos_core/util.py index 50c6880f87..0a66c5a837 100644 --- a/mlos_core/mlos_core/util.py +++ b/mlos_core/mlos_core/util.py @@ -28,7 +28,8 @@ def config_to_dataframe(config: Configuration) -> pd.DataFrame: def normalize_config( - config_space: ConfigurationSpace, config: Union[Configuration, dict] + config_space: ConfigurationSpace, + config: Union[Configuration, dict], ) -> Configuration: """ Convert a dictionary to a valid ConfigSpace configuration. From b51153f53ef22153e1e1cbbd279e527c3a71163b Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 21:25:08 +0000 Subject: [PATCH 51/54] multiline args --- mlos_bench/mlos_bench/dict_templater.py | 5 +++- .../local/local_env_telemetry_test.py | 3 ++- .../mlos_bench/tests/launcher_run_test.py | 5 +++- .../remote/azure/azure_fileshare_test.py | 12 +++++++--- .../remote/mock/mock_fileshare_service.py | 6 ++++- .../tests/services/remote/ssh/fixtures.py | 3 ++- .../services/remote/ssh/test_ssh_fileshare.py | 6 +++-- .../mlos_bench/tests/storage/exp_data_test.py | 3 ++- .../mlos_bench/tests/storage/exp_load_test.py | 24 ++++++++++++++----- .../mlos_bench/tests/storage/sql/fixtures.py | 9 ++++--- .../tests/storage/tunable_config_data_test.py | 3 ++- mlos_bench/mlos_bench/util.py | 17 +++++++++---- .../bayesian_optimizers/bayesian_optimizer.py | 5 +++- mlos_core/mlos_core/optimizers/optimizer.py | 9 +++++-- .../tests/optimizers/one_hot_test.py | 16 +++++++++---- .../optimizers/optimizer_multiobj_test.py | 3 ++- 16 files changed, 96 insertions(+), 33 deletions(-) diff --git a/mlos_bench/mlos_bench/dict_templater.py b/mlos_bench/mlos_bench/dict_templater.py index ab49e2baf8..e209f12bed 100644 --- a/mlos_bench/mlos_bench/dict_templater.py +++ b/mlos_bench/mlos_bench/dict_templater.py @@ -29,7 +29,10 @@ def __init__(self, source_dict: Dict[str, Any]): self._dict: Dict[str, Any] = {} def expand_vars( - self, *, extra_source_dict: Optional[Dict[str, Any]] = None, use_os_env: bool = False + self, + *, + extra_source_dict: Optional[Dict[str, Any]] = None, + use_os_env: bool = False, ) -> Dict[str, Any]: """ Expand the template variables in the destination dictionary. 
diff --git a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py index 4dc9ae43dc..9cda41f14d 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/local_env_telemetry_test.py @@ -73,7 +73,8 @@ def test_local_env_telemetry(tunable_groups: TunableGroups, zone_info: Optional[ # FIXME: This fails with zone_info = None when run with `TZ="America/Chicago pytest -n0 ...` @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_local_env_telemetry_no_header( - tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """Read the telemetry data with no header.""" ts1 = datetime.now(zone_info) diff --git a/mlos_bench/mlos_bench/tests/launcher_run_test.py b/mlos_bench/mlos_bench/tests/launcher_run_test.py index bb16190bfb..1ae5af7e11 100644 --- a/mlos_bench/mlos_bench/tests/launcher_run_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_run_test.py @@ -38,7 +38,10 @@ def local_exec_service() -> LocalExecService: def _launch_main_app( - root_path: str, local_exec_service: LocalExecService, cli_config: str, re_expected: List[str] + root_path: str, + local_exec_service: LocalExecService, + cli_config: str, + re_expected: List[str], ) -> None: """Run mlos_bench command-line application with given config and check the results in the log. diff --git a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py index 9b24e987db..79090a2f5f 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py +++ b/mlos_bench/mlos_bench/tests/services/remote/azure/azure_fileshare_test.py @@ -17,7 +17,9 @@ @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") def test_download_file( - mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService + mock_makedirs: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, ) -> None: filename = "test.csv" remote_folder = "a/remote/folder" @@ -71,7 +73,9 @@ def make_dir_client_returns(remote_folder: str) -> dict: @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") def test_download_folder_non_recursive( - mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService + mock_makedirs: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, ) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" @@ -103,7 +107,9 @@ def test_download_folder_non_recursive( @patch("mlos_bench.services.remote.azure.azure_fileshare.open") @patch("mlos_bench.services.remote.azure.azure_fileshare.os.makedirs") def test_download_folder_recursive( - mock_makedirs: MagicMock, mock_open: MagicMock, azure_fileshare: AzureFileShareService + mock_makedirs: MagicMock, + mock_open: MagicMock, + azure_fileshare: AzureFileShareService, ) -> None: remote_folder = "a/remote/folder" local_folder = "some/local/folder" diff --git a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py index 7cbe47a07c..2d227e635e 100644 --- 
a/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py +++ b/mlos_bench/mlos_bench/tests/services/remote/mock/mock_fileshare_service.py @@ -34,7 +34,11 @@ def __init__( self._download: List[Tuple[str, str]] = [] def upload( - self, params: dict, local_path: str, remote_path: str, recursive: bool = True + self, + params: dict, + local_path: str, + remote_path: str, + recursive: bool = True, ) -> None: self._upload.append((local_path, remote_path)) diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py index e9bd43124b..f4042cf62f 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/fixtures.py @@ -97,7 +97,8 @@ def ssh_test_server( @pytest.fixture(scope="session") def alt_test_server( - ssh_test_server: SshTestServerInfo, locked_docker_services: DockerServices + ssh_test_server: SshTestServerInfo, + locked_docker_services: DockerServices, ) -> SshTestServerInfo: """ Fixture for getting the second ssh test server info from the docker-compose.yml. diff --git a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py index a319f2e5bf..c0bb730a1e 100644 --- a/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py +++ b/mlos_bench/mlos_bench/tests/services/remote/ssh/test_ssh_fileshare.py @@ -51,7 +51,8 @@ def closeable_temp_file(**kwargs: Any) -> Generator[_TemporaryFileWrapper, None, @requires_docker def test_ssh_fileshare_single_file( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService + ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService, ) -> None: """Test the SshFileShareService single file download/upload.""" with ssh_fileshare_service: @@ -92,7 +93,8 @@ def test_ssh_fileshare_single_file( @requires_docker def test_ssh_fileshare_recursive( - ssh_test_server: SshTestServerInfo, ssh_fileshare_service: SshFileShareService + ssh_test_server: SshTestServerInfo, + ssh_fileshare_service: SshFileShareService, ) -> None: """Test the SshFileShareService recursive download/upload.""" with ssh_fileshare_service: diff --git a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py index 9931df7ef7..256c5d3b38 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_data_test.py @@ -19,7 +19,8 @@ def test_load_empty_exp_data(storage: Storage, exp_storage: Storage.Experiment) def test_exp_data_root_env_config( - exp_storage: Storage.Experiment, exp_data: ExperimentData + exp_storage: Storage.Experiment, + exp_data: ExperimentData, ) -> None: """Tests the root_env_config property of ExperimentData.""" # pylint: disable=protected-access diff --git a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py index 259d880f92..6a7d05bb2a 100644 --- a/mlos_bench/mlos_bench/tests/storage/exp_load_test.py +++ b/mlos_bench/mlos_bench/tests/storage/exp_load_test.py @@ -32,7 +32,9 @@ def test_exp_pending_empty(exp_storage: Storage.Experiment) -> None: @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_exp_trial_pending( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: 
"""Start a trial and check that it is pending.""" trial = exp_storage.new_trial(tunable_groups) @@ -43,7 +45,9 @@ def test_exp_trial_pending( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_exp_trial_pending_many( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """Start THREE trials and check that both are pending.""" config1 = tunable_groups.copy().assign({"idle": "mwait"}) @@ -63,7 +67,9 @@ def test_exp_trial_pending_many( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_exp_trial_pending_fail( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """Start a trial, fail it, and and check that it is NOT pending.""" trial = exp_storage.new_trial(tunable_groups) @@ -74,7 +80,9 @@ def test_exp_trial_pending_fail( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_exp_trial_success( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """Start a trial, finish it successfully, and and check that it is NOT pending.""" trial = exp_storage.new_trial(tunable_groups) @@ -85,7 +93,9 @@ def test_exp_trial_success( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_exp_trial_update_categ( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """Update the trial with multiple metrics, some of which are categorical.""" trial = exp_storage.new_trial(tunable_groups) @@ -107,7 +117,9 @@ def test_exp_trial_update_categ( @pytest.mark.parametrize(("zone_info"), ZONE_INFO) def test_exp_trial_update_twice( - exp_storage: Storage.Experiment, tunable_groups: TunableGroups, zone_info: Optional[tzinfo] + exp_storage: Storage.Experiment, + tunable_groups: TunableGroups, + zone_info: Optional[tzinfo], ) -> None: """Update the trial status twice and receive an error.""" trial = exp_storage.new_trial(tunable_groups) diff --git a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py index ec8515b412..8a9065e436 100644 --- a/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py +++ b/mlos_bench/mlos_bench/tests/storage/sql/fixtures.py @@ -107,7 +107,8 @@ def mixed_numerics_exp_storage( def _dummy_run_exp( - exp: SqlStorage.Experiment, tunable_name: Optional[str] + exp: SqlStorage.Experiment, + tunable_name: Optional[str], ) -> SqlStorage.Experiment: """Generates data by doing a simulated run of the given experiment.""" # Add some trials to that experiment. 
@@ -200,7 +201,8 @@ def mixed_numerics_exp_storage_with_trials( @pytest.fixture def exp_data( - storage: SqlStorage, exp_storage_with_trials: SqlStorage.Experiment + storage: SqlStorage, + exp_storage_with_trials: SqlStorage.Experiment, ) -> ExperimentData: """Test fixture for ExperimentData.""" return storage.experiments[exp_storage_with_trials.experiment_id] @@ -208,7 +210,8 @@ def exp_data( @pytest.fixture def exp_no_tunables_data( - storage: SqlStorage, exp_no_tunables_storage_with_trials: SqlStorage.Experiment + storage: SqlStorage, + exp_no_tunables_storage_with_trials: SqlStorage.Experiment, ) -> ExperimentData: """Test fixture for ExperimentData with no tunable configs.""" return storage.experiments[exp_no_tunables_storage_with_trials.experiment_id] diff --git a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py index 4173ae2761..755fc0205a 100644 --- a/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py +++ b/mlos_bench/mlos_bench/tests/storage/tunable_config_data_test.py @@ -9,7 +9,8 @@ def test_trial_data_tunable_config_data( - exp_data: ExperimentData, tunable_groups: TunableGroups + exp_data: ExperimentData, + tunable_groups: TunableGroups, ) -> None: """Check expected return values for TunableConfigData.""" trial_id = 1 diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index dc86901651..64d3600966 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -70,7 +70,10 @@ def preprocess_dynamic_configs(*, dest: dict, source: Optional[dict] = None) -> def merge_parameters( - *, dest: dict, source: Optional[dict] = None, required_keys: Optional[Iterable[str]] = None + *, + dest: dict, + source: Optional[dict] = None, + required_keys: Optional[Iterable[str]] = None, ) -> dict: """ Merge the source config dict into the destination config. Pick from the source @@ -131,7 +134,8 @@ def path_join(*args: str, abs_path: bool = False) -> str: def prepare_class_load( - config: dict, global_config: Optional[Dict[str, Any]] = None + config: dict, + global_config: Optional[Dict[str, Any]] = None, ) -> Tuple[str, Dict[str, Any]]: """ Extract the class instantiation parameters from the configuration. @@ -188,7 +192,10 @@ def get_class_from_name(class_name: str) -> type: # FIXME: Technically, this should return a type "class_name" derived from "base_class". def instantiate_from_config( - base_class: Type[BaseTypeVar], class_name: str, *args: Any, **kwargs: Any + base_class: Type[BaseTypeVar], + class_name: str, + *args: Any, + **kwargs: Any, ) -> BaseTypeVar: """ Factory method for a new class instantiated from config. 
@@ -361,7 +368,9 @@ def utcify_timestamp(timestamp: datetime, *, origin: Literal["utc", "local"]) -> def utcify_nullable_timestamp( - timestamp: Optional[datetime], *, origin: Literal["utc", "local"] + timestamp: Optional[datetime], + *, + origin: Literal["utc", "local"], ) -> Optional[datetime]: """A nullable version of utcify_timestamp.""" return utcify_timestamp(timestamp, origin=origin) if timestamp is not None else None diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py index 2f7f71672b..a39a5516e8 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/bayesian_optimizer.py @@ -18,7 +18,10 @@ class BaseBayesianOptimizer(BaseOptimizer, metaclass=ABCMeta): @abstractmethod def surrogate_predict( - self, *, configs: pd.DataFrame, context: Optional[pd.DataFrame] = None + self, + *, + configs: pd.DataFrame, + context: Optional[pd.DataFrame] = None, ) -> npt.NDArray: """ Obtain a prediction from this Bayesian optimizer's surrogate model for the given diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 70a2010a45..4152e3c4c0 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -144,7 +144,10 @@ def _register( pass # pylint: disable=unnecessary-pass # pragma: no cover def suggest( - self, *, context: Optional[pd.DataFrame] = None, defaults: bool = False + self, + *, + context: Optional[pd.DataFrame] = None, + defaults: bool = False, ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Wrapper method, which employs the space adapter (if any), after suggesting a new @@ -185,7 +188,9 @@ def suggest( @abstractmethod def _suggest( - self, *, context: Optional[pd.DataFrame] = None + self, + *, + context: Optional[pd.DataFrame] = None, ) -> Tuple[pd.DataFrame, Optional[pd.DataFrame]]: """ Suggests a new configuration. 
diff --git a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py index 0325c764d2..c910f60fc5 100644 --- a/mlos_core/mlos_core/tests/optimizers/one_hot_test.py +++ b/mlos_core/mlos_core/tests/optimizers/one_hot_test.py @@ -91,28 +91,36 @@ def optimizer(configuration_space: CS.ConfigurationSpace) -> BaseOptimizer: def test_to_1hot_data_frame( - optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray + optimizer: BaseOptimizer, + data_frame: pd.DataFrame, + one_hot_data_frame: npt.NDArray, ) -> None: """Toy problem to test one-hot encoding of dataframe.""" assert optimizer._to_1hot(config=data_frame) == pytest.approx(one_hot_data_frame) def test_to_1hot_series( - optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray + optimizer: BaseOptimizer, + series: pd.Series, + one_hot_series: npt.NDArray, ) -> None: """Toy problem to test one-hot encoding of series.""" assert optimizer._to_1hot(config=series) == pytest.approx(one_hot_series) def test_from_1hot_data_frame( - optimizer: BaseOptimizer, data_frame: pd.DataFrame, one_hot_data_frame: npt.NDArray + optimizer: BaseOptimizer, + data_frame: pd.DataFrame, + one_hot_data_frame: npt.NDArray, ) -> None: """Toy problem to test one-hot decoding of dataframe.""" assert optimizer._from_1hot(config=one_hot_data_frame).to_dict() == data_frame.to_dict() def test_from_1hot_series( - optimizer: BaseOptimizer, series: pd.Series, one_hot_series: npt.NDArray + optimizer: BaseOptimizer, + series: pd.Series, + one_hot_series: npt.NDArray, ) -> None: """Toy problem to test one-hot decoding of series.""" one_hot_df = optimizer._from_1hot(config=one_hot_series) diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py index 6ca009316f..748fd1cc82 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_multiobj_test.py @@ -25,7 +25,8 @@ ], ) def test_multi_target_opt_wrong_weights( - optimizer_class: Type[BaseOptimizer], kwargs: dict + optimizer_class: Type[BaseOptimizer], + kwargs: dict, ) -> None: """Make sure that the optimizer raises an error if the number of objective weights does not match the number of optimization targets. 
From ad1982e3b6d909c675cbe7ca179653bea3651765 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 21:34:39 +0000 Subject: [PATCH 52/54] tweaks --- .../tests/environments/base_env_test.py | 11 ++++++++++- .../local/composite_local_env_test.py | 18 ++++++++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/mlos_bench/mlos_bench/tests/environments/base_env_test.py b/mlos_bench/mlos_bench/tests/environments/base_env_test.py index e7e17e6df7..04f9e8c54c 100644 --- a/mlos_bench/mlos_bench/tests/environments/base_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/base_env_test.py @@ -24,7 +24,16 @@ def test_expand_groups() -> None: """Check the dollar variable expansion for tunable groups.""" - assert Environment._expand_groups(["begin", "$list", "$empty", "$str", "end"], _GROUPS) == [ + assert Environment._expand_groups( + [ + "begin", + "$list", + "$empty", + "$str", + "end", + ], + _GROUPS, + ) == [ "begin", "c", "d", diff --git a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py index 1f3cf66110..4d15a6fcee 100644 --- a/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py +++ b/mlos_bench/mlos_bench/tests/environments/local/composite_local_env_test.py @@ -65,9 +65,12 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo }, "required_args": ["errors", "reads"], "shell_env_params": [ - "latency", # const_args overridden by the composite env - "errors", # Comes from the parent const_args - "reads", # const_args overridden by the global config + # const_args overridden by the composite env + "latency", + # Comes from the parent const_args + "errors", + # const_args overridden by the global config + "reads", ], "run": [ "echo 'metric,value' > output.csv", @@ -89,9 +92,12 @@ def test_composite_env(tunable_groups: TunableGroups, zone_info: Optional[tzinfo }, "required_args": ["writes"], "shell_env_params": [ - "throughput", # const_args overridden by the composite env - "score", # Comes from the local const_args - "writes", # Comes straight from the global config + # const_args overridden by the composite env + "throughput", + # Comes from the local const_args + "score", + # Comes straight from the global config + "writes", ], "run": [ "echo 'metric,value' > output.csv", From fb56118b5341d551b5818f78286b1dd0ca7f5abe Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Wed, 10 Jul 2024 21:44:45 +0000 Subject: [PATCH 53/54] tweaks --- mlos_bench/mlos_bench/os_environ.py | 5 ++--- mlos_bench/mlos_bench/tunables/tunable.py | 18 +++++++++++++++--- .../mlos_bench/tunables/tunable_groups.py | 7 +++---- 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/mlos_bench/mlos_bench/os_environ.py b/mlos_bench/mlos_bench/os_environ.py index 348cf1ffa0..f750f12038 100644 --- a/mlos_bench/mlos_bench/os_environ.py +++ b/mlos_bench/mlos_bench/os_environ.py @@ -22,9 +22,8 @@ from typing_extensions import TypeAlias if sys.version_info >= (3, 9): - EnvironType: TypeAlias = os._Environ[ - str - ] # pylint: disable=protected-access,disable=unsubscriptable-object + # pylint: disable=protected-access,disable=unsubscriptable-object + EnvironType: TypeAlias = os._Environ[str] else: EnvironType: TypeAlias = os._Environ # pylint: disable=protected-access diff --git a/mlos_bench/mlos_bench/tunables/tunable.py b/mlos_bench/mlos_bench/tunables/tunable.py index 1a6d3a804b..8f9bb48bff 100644 --- 
a/mlos_bench/mlos_bench/tunables/tunable.py +++ b/mlos_bench/mlos_bench/tunables/tunable.py @@ -327,19 +327,31 @@ def value(self, value: TunableValue) -> TunableValue: coerced_value = self.dtype(value) except Exception: _LOG.error( - "Impossible conversion: %s %s <- %s %s", self._type, self._name, type(value), value + "Impossible conversion: %s %s <- %s %s", + self._type, + self._name, + type(value), + value, ) raise if self._type == "int" and isinstance(value, float) and value != coerced_value: _LOG.error( - "Loss of precision: %s %s <- %s %s", self._type, self._name, type(value), value + "Loss of precision: %s %s <- %s %s", + self._type, + self._name, + type(value), + value, ) raise ValueError(f"Loss of precision: {self._name}={value}") if not self.is_valid(coerced_value): _LOG.error( - "Invalid assignment: %s %s <- %s %s", self._type, self._name, type(value), value + "Invalid assignment: %s %s <- %s %s", + self._type, + self._name, + type(value), + value, ) raise ValueError(f"Invalid value for the tunable: {self._name}={value}") diff --git a/mlos_bench/mlos_bench/tunables/tunable_groups.py b/mlos_bench/mlos_bench/tunables/tunable_groups.py index b3e3698c61..da56eb79ac 100644 --- a/mlos_bench/mlos_bench/tunables/tunable_groups.py +++ b/mlos_bench/mlos_bench/tunables/tunable_groups.py @@ -26,9 +26,8 @@ def __init__(self, config: Optional[dict] = None): if config is None: config = {} ConfigSchema.TUNABLE_PARAMS.validate(config) - self._index: Dict[str, CovariantTunableGroup] = ( - {} - ) # Index (Tunable id -> CovariantTunableGroup) + # Index (Tunable id -> CovariantTunableGroup) + self._index: Dict[str, CovariantTunableGroup] = {} self._tunable_groups: Dict[str, CovariantTunableGroup] = {} for name, group_config in config.items(): self._add_group(CovariantTunableGroup(name, group_config)) @@ -123,7 +122,7 @@ def merge(self, tunables: "TunableGroups") -> "TunableGroups": if not self._tunable_groups[group.name].equals_defaults(group): raise ValueError( f"Overlapping covariant tunable group name {group.name} " - + "in {self._tunable_groups[group.name]} and {tunables}" + "in {self._tunable_groups[group.name]} and {tunables}" ) return self From 8afeec649db7312b693d5fc4c14d4afe2cec2e83 Mon Sep 17 00:00:00 2001 From: Brian Kroth Date: Fri, 12 Jul 2024 19:06:22 +0000 Subject: [PATCH 54/54] format --- mlos_core/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlos_core/setup.py b/mlos_core/setup.py index 2bd5ade3d1..e4c792e270 100644 --- a/mlos_core/setup.py +++ b/mlos_core/setup.py @@ -97,7 +97,7 @@ def _get_long_desc_from_readme(base_url: str) -> dict: 'pandas >= 2.2.0;python_version>="3.9"', 'Bottleneck > 1.3.5;python_version>="3.9"', 'pandas >= 1.0.3;python_version<"3.9"', - "ConfigSpace==0.7.1", # Temporarily restrict ConfigSpace version. + "ConfigSpace==0.7.1", # Temporarily restrict ConfigSpace version. ], extras_require=extra_requires, **_get_long_desc_from_readme("https://github.com/microsoft/MLOS/tree/main/mlos_core"),