diff --git a/bencher/bencher.py b/bencher/bencher.py index 6a973ef0..36b7f98e 100644 --- a/bencher/bencher.py +++ b/bencher/bencher.py @@ -315,8 +315,21 @@ def plot_sweep( else: const_vars = deepcopy(const_vars) - for i in range(len(input_vars)): - input_vars[i] = self.convert_vars_to_params(input_vars[i], "input") + if isinstance(input_vars, dict): + input_lists = [] + for k, v in input_vars.items(): + param_var = self.convert_vars_to_params(k, "input") + if isinstance(v, list): + assert len(v) > 0 + param_var = param_var.with_sample_values(v) + else: + raise RuntimeError("Unsupported type") + input_lists.append(param_var) + + input_vars = input_lists + else: + for i in range(len(input_vars)): + input_vars[i] = self.convert_vars_to_params(input_vars[i], "input") for i in range(len(result_vars)): result_vars[i] = self.convert_vars_to_params(result_vars[i], "result") @@ -484,6 +497,10 @@ def convert_vars_to_params(self, variable: param.Parameter, var_type: str): """ if isinstance(variable, str): variable = self.worker_class_instance.param.objects(instance=False)[variable] + if isinstance(variable, tuple): + variable = self.worker_class_instance.param.objects(instance=False)[ + variable[0] + ].with_sample_values(variable[1]) if not isinstance(variable, param.Parameter): raise TypeError( f"You need to use {var_type}_vars =[{self.worker_input_cfg}.param.your_variable], instead of {var_type}_vars =[{self.worker_input_cfg}.your_variable]" diff --git a/bencher/example/example_custom_sweep2.py b/bencher/example/example_custom_sweep2.py new file mode 100644 index 00000000..ec10df18 --- /dev/null +++ b/bencher/example/example_custom_sweep2.py @@ -0,0 +1,40 @@ +import bencher as bch + + +class Square(bch.ParametrizedSweep): + """An example of a datatype with an integer and float parameter""" + + x = bch.FloatSweep(default=0, bounds=[0, 6]) + + result = bch.ResultVar("ul", doc="Square of x") + + def __call__(self, **kwargs) -> dict: + self.update_params_from_kwargs(**kwargs) + self.result = self.x * self.x + return self.get_results_values_as_dict() + + +def example_custom_sweep2( + run_cfg: bch.BenchRunCfg = None, report: bch.BenchReport = None +) -> bch.Bench: + """This example shows how to define a custom set of value to sample from intead of a uniform sweep + + Args: + run_cfg (BenchRunCfg): configuration of how to perform the param sweep + + Returns: + Bench: results of the parameter sweep + """ + + bench = Square().to_bench(run_cfg=run_cfg, report=report) + + # These are all equivalent + bench.plot_sweep(input_vars=[Square.param.x.with_sample_values([0, 1, 2])]) + bench.plot_sweep(input_vars=dict(x=[2, 3, 4])) + bench.plot_sweep(input_vars=[("x", [3, 4, 5])]) + + return bench + + +if __name__ == "__main__": + example_custom_sweep2().report.show() diff --git a/pixi.lock b/pixi.lock index 41c6b469..eda9d1ed 100644 --- a/pixi.lock +++ b/pixi.lock @@ -22,7 +22,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-h4ab18f5_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h59595ed_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.3-hab00c5b_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda @@ -112,7 +112,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/67/91/1f55f0e026fba8eba15afb7d097bb873bd6a9e466be45a45e7cac40a930b/xarray-2024.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl - pypi: . py310: channels: @@ -135,7 +135,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-h4ab18f5_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h59595ed_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.10.14-hd12c33a_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda @@ -227,7 +227,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/67/91/1f55f0e026fba8eba15afb7d097bb873bd6a9e466be45a45e7cac40a930b/xarray-2024.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl - pypi: . py311: channels: @@ -251,7 +251,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-h4ab18f5_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h59595ed_0.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.11.9-hb806964_0_cpython.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda @@ -341,7 +341,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/67/91/1f55f0e026fba8eba15afb7d097bb873bd6a9e466be45a45e7cac40a930b/xarray-2024.5.0-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl - pypi: . packages: - kind: conda @@ -896,9 +896,9 @@ packages: requires_python: '>=3.7' - kind: pypi name: holobench - version: 1.23.0 + version: 1.24.0 path: . - sha256: 76ce5d13f28e35b37db41f02f45ee8dce425e1ab53b7bc2bc4a97322d908309a + sha256: 5d6bfc4999f8aae6d6207b11e7bcb5649ea91191c05fb7f739f2bfc03d926d15 requires_dist: - holoviews>=1.15,<=1.18.3 - numpy>=1.0,<=1.26.4 @@ -1424,6 +1424,7 @@ packages: constrains: - binutils_impl_linux-64 2.40 license: GPL-3.0-only + license_family: GPL purls: [] size: 708179 timestamp: 1717523002366 @@ -1835,13 +1836,12 @@ packages: requires_python: '>=3.9' - kind: conda name: openssl - version: 3.3.0 - build: h4ab18f5_3 - build_number: 3 + version: 3.3.1 + build: h4ab18f5_0 subdir: linux-64 - url: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda - sha256: 33dcea0ed3a61b2de6b66661cdd55278640eb99d676cd129fbff3e53641fa125 - md5: 12ea6d0d4ed54530eaed18e4835c1f7c + url: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda + sha256: 9691f8bd6394c5bb0b8d2f47cd1467b91bd5b1df923b69e6b517f54496ee4b50 + md5: a41fa0e391cc9e0d6b78ac69ca047a6c depends: - ca-certificates - libgcc-ng >=12 @@ -1850,8 +1850,8 @@ packages: license: Apache-2.0 license_family: Apache purls: [] - size: 2891147 - timestamp: 1716468354865 + size: 2896170 + timestamp: 1717546157673 - kind: pypi name: optuna version: 3.6.1 @@ -3548,9 +3548,9 @@ packages: requires_python: '>=3.9' - kind: pypi name: xyzservices - version: 2024.4.0 - url: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl - sha256: b83e48c5b776c9969fffcfff57b03d02b1b1cd6607a9d9c4e7f568b01ef47f4c + version: 2024.6.0 + url: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl + sha256: fecb2508f0f2b71c819aecf5df2c03cef001c56a4b49302e640f3b34710d25e4 requires_python: '>=3.8' - kind: conda name: xz diff --git a/pyproject.toml b/pyproject.toml index 85697550..c8964a25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "holobench" -version = "1.23.0" +version = "1.24.0" authors = [{ name = "Austin Gregg-Smith", email = "blooop@gmail.com" }] description = "A package for benchmarking the performance of arbitrary functions" @@ -80,8 +80,11 @@ test = "pytest" coverage = "coverage run -m pytest && coverage xml -o coverage.xml" coverage-report = "coverage report -m" update-lock = "pixi update && git commit -a -m'update pixi.lock'" +fix = { depends_on = ["update-lock", "format", "ruff-lint"] } push = "git push" update-lock-push = { depends_on = ["update-lock", "push"] } +fix-commit-push = { depends_on = ["fix", "commit-format", "update-lock-push"] } + ci-no-cover = { depends_on = ["style", "test"] } ci = { depends_on = [ "format", diff --git a/test/test_bench_examples.py b/test/test_bench_examples.py index 5916470e..a6221836 100644 --- a/test/test_bench_examples.py +++ b/test/test_bench_examples.py @@ -9,6 +9,7 @@ from bencher.example.example_float3D import example_floats3D from bencher.example.example_custom_sweep import example_custom_sweep +from bencher.example.example_custom_sweep2 import example_custom_sweep2 from bencher.example.example_workflow import example_floats2D_workflow, example_floats3D_workflow from bencher.example.example_holosweep import example_holosweep from bencher.example.example_holosweep_tap import example_holosweep_tap @@ -85,6 +86,9 @@ def test_example_float3D(self) -> None: def test_example_custom_sweep(self) -> None: self.examples_asserts(example_custom_sweep(self.create_run_cfg())) + def test_example_custom2(self) -> None: + self.examples_asserts(example_custom_sweep2(self.create_run_cfg())) + def test_example_floats2D_workflow(self) -> None: self.examples_asserts(example_floats2D_workflow(self.create_run_cfg()))