Skip to content

Commit

Permalink
Add basic iteration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AVHopp committed May 6, 2024
1 parent 1c5ed4a commit 5c5364c
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/test_iterations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest

from baybe.acquisition.base import AcquisitionFunction
from baybe.kernels import MaternKernel, ScaleKernel
from baybe.kernels.priors import (
GammaPrior,
HalfCauchyPrior,
Expand Down Expand Up @@ -128,6 +129,20 @@
SmoothedBoxPrior(0, 3, 0.1),
]

# Note that this test does not vary the priors as there is currently a separate test for
# this aspect.
valid_base_kernels = [MaternKernel(nu=nu) for nu in (0.5, 1.5, 2.5)]

# Due to numerical issues (i.e., matrix systems not being solvable) we do not test
# different priors here.
valid_scale_kernels = [
ScaleKernel(base_kernel=base_kernel, outputscale_prior=prior)
for base_kernel in valid_base_kernels
for prior in valid_priors
]

valid_kernels = valid_base_kernels + valid_scale_kernels

test_targets = [
["Target_max"],
["Target_min"],
Expand Down Expand Up @@ -169,6 +184,17 @@ def test_iter_prior(campaign, n_iterations, batch_size):
run_iterations(campaign, n_iterations, batch_size)


# For these tests, there were numerical issues without restricting the batch size to 1
@pytest.mark.slow
@pytest.mark.parametrize(
"kernel", valid_kernels, ids=[c.__class__ for c in valid_kernels]
)
@pytest.mark.parametrize("n_iterations", [3], ids=["i3"])
@pytest.mark.parametrize("batch_size", [1], ids=["b1"])
def test_iter_kernel(campaign, n_iterations, batch_size):
run_iterations(campaign, n_iterations, batch_size)


@pytest.mark.slow
@pytest.mark.parametrize(
"surrogate_model",
Expand Down

0 comments on commit 5c5364c

Please sign in to comment.