Skip to content

Commit

Permalink
lint with ufmt
Browse files Browse the repository at this point in the history
  • Loading branch information
CompRhys committed Oct 21, 2024
1 parent b3e8aaf commit 01db715
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
14 changes: 6 additions & 8 deletions botorch/optim/optimize_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,19 @@

from __future__ import annotations

from collections.abc import Callable

from typing import Any

import torch
from botorch.acquisition import AcquisitionFunction

from botorch.generation.gen import TGenCandidates
from botorch.optim.homotopy import Homotopy
from botorch.optim.initializers import TGenInitialConditions
from botorch.optim.optimize import optimize_acqf
from torch import Tensor

from collections.abc import Callable

from botorch.generation.gen import TGenCandidates
from botorch.optim.initializers import (
TGenInitialConditions,
)


def prune_candidates(
candidates: Tensor, acq_values: Tensor, prune_tolerance: float
Expand Down Expand Up @@ -188,7 +186,7 @@ def optimize_acqf_homotopy(
q=1,
options=options,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs
**shared_optimize_acqf_kwargs,
)
homotopy.step()

Expand Down
12 changes: 7 additions & 5 deletions test/optim/test_homotopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,13 @@ def test_optimize_acqf_homotopy(self):
self.assertEqual(acqf_val.shape, torch.Size([3]))

# with linear constraints
constraints = [( # X[..., 0] + X[..., 1] >= 2.
torch.tensor([0, 1], device=self.device),
torch.ones(2, device=self.device, dtype=torch.double),
2.0,
)]
constraints = [
( # X[..., 0] + X[..., 1] >= 2.
torch.tensor([0, 1], device=self.device),
torch.ones(2, device=self.device, dtype=torch.double),
2.0,
)
]

acqf = PosteriorMean(model=model)
candidate, acqf_val = optimize_acqf_homotopy(
Expand Down

0 comments on commit 01db715

Please sign in to comment.