From 6ed79633c1058be9ef674cc64053a7fb16d50914 Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Wed, 9 Aug 2023 14:46:55 -0700 Subject: [PATCH] Don't allow unused **kwargs in input_constructors except for a defined set of exceptions (#1772) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1772 X-link: https://github.com/pytorch/botorch/pull/1872 [x] Remove unused arguments from input constructors and related functions. The idea is especially not to let unused keyword arguments disappear into `**kwargs` and be silently ignored [x] add arguments to some input constructors so they don't need any `**kwargs` [x] Add a decorator that ensures that each input constructor can accept a certain set of keyword arguments, even if those are not used are the constructor, while still erroring on [ ] Prevent arguments from having different defaults in the input constructors as in acquisition functions Reviewed By: SebastianAment Differential Revision: D46519588 fbshipit-source-id: 44c5c6e99a4e3ae3b287da9901a9e20af322239e --- ax/models/torch/botorch_modular/sebo.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ax/models/torch/botorch_modular/sebo.py b/ax/models/torch/botorch_modular/sebo.py index 1431a799e2f..289040158ee 100644 --- a/ax/models/torch/botorch_modular/sebo.py +++ b/ax/models/torch/botorch_modular/sebo.py @@ -61,8 +61,8 @@ def __init__( tkwargs = {"dtype": surrogate.dtype, "device": surrogate.device} options = options or {} - self.penalty_name: str = options.get("penalty", "L0_norm") - self.target_point: Tensor = options.get("target_point", None) + self.penalty_name: str = options.pop("penalty", "L0_norm") + self.target_point: Tensor = options.pop("target_point", None) if self.target_point is None: raise ValueError("please provide target point.") self.target_point.to(**tkwargs) # pyre-ignore @@ -93,6 +93,8 @@ def __init__( ) # instantiate botorch_acqf_class + if not issubclass(botorch_acqf_class, qExpectedHypervolumeImprovement): + raise ValueError("botorch_acqf_class must be qEHVI to use SEBO") super().__init__( surrogates={"sebo": surrogate_f}, search_space_digest=search_space_digest,