
Commit de5d0ec

Merge fc601b8 into cca54db
saitcakmak authored Jul 27, 2023
2 parents: cca54db + fc601b8 (commit de5d0ec)

Showing 2 changed files with 15 additions and 5 deletions.

botorch/optim/homotopy.py (5 changes: 0 additions & 5 deletions)

@@ -20,29 +20,24 @@ class HomotopySchedule(ABC):
     @abstractmethod
     def num_steps(self) -> int:
         """Number of steps in the schedule."""
-        pass

     @property
     @abstractmethod
     def value(self) -> Any:
         """Current value in the schedule."""
-        pass

     @property
     @abstractmethod
     def should_stop(self) -> bool:
         """Return true if we have incremented past the end of the schedule."""
-        pass

     @abstractmethod
     def restart(self) -> None:
         """Restart the schedule to start from the beginning."""
-        pass

     @abstractmethod
     def step(self) -> None:
         """Move to solving the next problem."""
-        pass


 class FixedHomotopySchedule(HomotopySchedule):
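Note: for readers unfamiliar with this interface, a minimal concrete schedule implementing the abstract members above could look like the sketch below. The class name and stepping logic are illustrative assumptions only and are not part of this commit; BoTorch's own FixedHomotopySchedule fills this role in the library.

    # Illustrative sketch only (not part of this commit): a hypothetical
    # concrete schedule implementing the abstract members shown above.
    # BoTorch's FixedHomotopySchedule plays this role in the library itself.
    from typing import Any

    from botorch.optim.homotopy import HomotopySchedule


    class ListSchedule(HomotopySchedule):
        """Steps through a fixed list of values, one value per homotopy step."""

        def __init__(self, values: list) -> None:
            self._values = values
            self._idx = 0

        @property
        def num_steps(self) -> int:
            # Total number of sub-problems the homotopy will solve.
            return len(self._values)

        @property
        def value(self) -> Any:
            # Value of the homotopy parameter for the current step.
            return self._values[self._idx]

        @property
        def should_stop(self) -> bool:
            # True once we have stepped past the end of the schedule.
            return self._idx >= len(self._values)

        def restart(self) -> None:
            self._idx = 0

        def step(self) -> None:
            self._idx += 1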

test/optim/test_homotopy.py (15 changes: 15 additions & 0 deletions)

@@ -7,6 +7,7 @@

 import torch
 from botorch.acquisition import PosteriorMean
+from botorch.acquisition.monte_carlo import qExpectedImprovement
 from botorch.models import GenericDeterministicModel
 from botorch.optim.homotopy import (
     FixedHomotopySchedule,

@@ -142,6 +143,20 @@ def test_optimize_acqf_homotopy(self):
         )
         self.assertEqual(candidate[0, 0], torch.tensor(1, **tkwargs))

+        # With q > 1.
+        acqf = qExpectedImprovement(model=model, best_f=0.0)
+        candidate, acqf_val = optimize_acqf_homotopy(
+            q=3,
+            acq_function=acqf,
+            bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
+            homotopy=Homotopy(homotopy_parameters=[hp]),
+            num_restarts=2,
+            raw_samples=16,
+            fixed_features=fixed_features,
+        )
+        self.assertEqual(candidate.shape, torch.Size([3, 2]))
+        self.assertEqual(acqf_val.shape, torch.Size([3]))
+
     def test_prune_candidates(self):
         tkwargs = {"device": self.device, "dtype": torch.double}
         # no pruning
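The new q > 1 case reuses names defined earlier in test_optimize_acqf_homotopy (tkwargs, model, hp, fixed_features) that are not visible in this hunk; with q=3 over a 2-dimensional domain, the assertions check for a 3 x 2 candidate tensor and 3 acquisition values. A rough, hypothetical sketch of how such a setup is typically assembled follows; the exact values in the test may differ.

    # Hypothetical setup sketch for the names the new test block references
    # (tkwargs, model, hp, fixed_features). The real definitions live earlier
    # in test_optimize_acqf_homotopy and may differ in detail.
    import torch
    from botorch.models import GenericDeterministicModel
    from botorch.optim.homotopy import FixedHomotopySchedule, HomotopyParameter

    tkwargs = {"device": torch.device("cpu"), "dtype": torch.double}

    # Homotopy parameter that is annealed over a fixed schedule of values.
    p = torch.nn.Parameter(torch.zeros(1, **tkwargs), requires_grad=False)
    hp = HomotopyParameter(
        parameter=p,
        schedule=FixedHomotopySchedule(values=[0.1, 0.01, 0.001]),
    )

    # Simple deterministic single-output model whose landscape depends on p.
    model = GenericDeterministicModel(
        f=lambda x: 5 - ((x - p) ** 2).sum(dim=-1, keepdim=True)
    )

    # Pin the first input dimension to 1.0 during acquisition optimization.
    fixed_features = {0: 1.0}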
