Skip to content

Commit

Permalink
Add GPBandit multiobjective as default by using from_problem factory.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 583067422
  • Loading branch information
qiuyiz authored and copybara-github committed Nov 16, 2023
1 parent a94eb5a commit a6bcfc9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
5 changes: 4 additions & 1 deletion vizier/_src/service/policy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,17 @@ def __call__(
study_name: str,
) -> pythia.Policy:
del study_name
# TODO: Compare string to Algorithm enums.
if algorithm in (
'DEFAULT',
'ALGORITHM_UNSPECIFIED',
'GAUSSIAN_PROCESS_BANDIT',
):
from vizier._src.algorithms.designers import gp_bandit

return dp.DesignerPolicy(policy_supporter, gp_bandit.VizierGPBandit)
return dp.DesignerPolicy(
policy_supporter, gp_bandit.VizierGPBandit.from_problem
)
elif algorithm == 'RANDOM_SEARCH':
from vizier._src.algorithms.policies import random_policy

Expand Down
22 changes: 21 additions & 1 deletion vizier/_src/service/vizier_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,32 @@ def test_get_suggestions(self):
dict(algorithm=pyvizier.Algorithm.QUASI_RANDOM_SEARCH),
dict(algorithm=pyvizier.Algorithm.GRID_SEARCH),
dict(algorithm=pyvizier.Algorithm.NSGA2, multi_objective=True),
dict(
algorithm=pyvizier.Algorithm.ALGORITHM_UNSPECIFIED,
multi_objective=True,
),
dict(
algorithm='DEFAULT',
multi_objective=True,
),
dict(
algorithm='DEFAULT',
multi_objective=False,
),
dict(
algorithm=pyvizier.Algorithm.GAUSSIAN_PROCESS_BANDIT,
multi_objective=True,
),
dict(
algorithm=pyvizier.Algorithm.GAUSSIAN_PROCESS_BANDIT,
multi_objective=False,
),
)
def test_e2e_tuning(
self,
*,
algorithm,
num_iterations: int = 50,
num_iterations: int = 10,
batch_size: int = 1,
multi_objective: bool = False,
):
Expand Down

0 comments on commit a6bcfc9

Please sign in to comment.