Improve typing of *_kwargs fields in ModelSpec (#2565)
Summary:
Pull Request resolved: #2565

These were typed as `Optional[Dict[str, Any]]` but immediately converted to `Dict[str, Any]` in `__post_init__`. Internal usage also included a number of `self.*_kwargs or {}` expressions, presumably to keep Pyre happy with the optional fields.

This diff updates the type hints to `Dict[str, Any]` and keeps these arguments optional at initialization by using `field(default_factory=dict)`, which assigns an empty dict as the default. The change is also extended to `GenerationStep`.

The `__post_init__` method is kept in place to preserve backwards compatibility with any previous usage of `None`.
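
For reference, a minimal sketch of the pattern described above (the `ExampleSpec` dataclass is a hypothetical stand-in for illustration, not the actual Ax class): the kwargs fields get a plain `Dict[str, Any]` hint with `field(default_factory=dict)` supplying a fresh empty dict per instance, while `__post_init__` still coerces an explicit `None` to `{}` for backwards compatibility.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ExampleSpec:
    # default_factory gives each instance its own empty dict (a mutable
    # default like `= {}` would be shared and is rejected by dataclasses).
    model_kwargs: Dict[str, Any] = field(default_factory=dict)
    model_gen_kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Backwards compatibility: callers that still pass None get an empty dict.
        self.model_kwargs = self.model_kwargs if self.model_kwargs is not None else {}
        self.model_gen_kwargs = (
            self.model_gen_kwargs if self.model_gen_kwargs is not None else {}
        )


spec = ExampleSpec()
assert spec.model_kwargs == {}  # no `spec.model_kwargs or {}` needed downstream
legacy = ExampleSpec(model_kwargs=None)  # type checkers may flag this, but it still works
assert legacy.model_kwargs == {}
```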

Reviewed By: Balandat

Differential Revision: D59403745

fbshipit-source-id: 5f26ea20d339559411006a82f73e72ad36e68ff0
saitcakmak authored and facebook-github-bot committed Jul 9, 2024
1 parent 2428521 commit 5862f7f
Showing 9 changed files with 32 additions and 45 deletions.
2 changes: 1 addition & 1 deletion ax/benchmark/methods/modular_botorch.py
@@ -128,7 +128,7 @@ def get_sobol_botorch_modular_acquisition(
model=Models.BOTORCH_MODULAR,
num_trials=-1,
model_kwargs=model_kwargs,
- model_gen_kwargs=model_gen_kwargs,
+ model_gen_kwargs=model_gen_kwargs or {},
),
],
)
3 changes: 1 addition & 2 deletions ax/modelbridge/dispatch_utils.py
@@ -132,8 +132,7 @@ def _make_botorch_step(
min_trials_observed=min_trials_observed or ceil(num_trials / 2),
enforce_num_trials=enforce_num_trials,
max_parallelism=max_parallelism,
- # `model_kwargs` should default to `None` if empty
- model_kwargs=model_kwargs if len(model_kwargs) > 0 else None,
+ model_kwargs=model_kwargs,
should_deduplicate=should_deduplicate,
)

9 changes: 7 additions & 2 deletions ax/modelbridge/generation_node.py
@@ -615,9 +615,9 @@ class GenerationStep(GenerationNode, SortableBase):

# Optional model specifications:
# Kwargs to pass into the Models constructor (or factory function).
- model_kwargs: Optional[Dict[str, Any]] = None
+ model_kwargs: Dict[str, Any] = field(default_factory=dict)
# Kwargs to pass into the Model's `.gen` function.
- model_gen_kwargs: Optional[Dict[str, Any]] = None
+ model_gen_kwargs: Dict[str, Any] = field(default_factory=dict)

# Optional specifications for use in generation strategy:
completion_criteria: Sequence[TransitionCriterion] = field(default_factory=list)
@@ -651,6 +651,11 @@ def __post_init__(self) -> None:
f"{self.num_trials}`), making completion of this step impossible. "
"Please alter inputs so that `min_trials_observed <= num_trials`."
)
+ # For backwards compatibility with None / Optional input.
+ self.model_kwargs = self.model_kwargs if self.model_kwargs is not None else {}
+ self.model_gen_kwargs = (
+ self.model_gen_kwargs if self.model_gen_kwargs is not None else {}
+ )
if not isinstance(self.model, ModelRegistryBase):
if not callable(self.model):
raise UserInputError(
40 changes: 16 additions & 24 deletions ax/modelbridge/model_spec.py
@@ -11,7 +11,7 @@
import json
import warnings
from copy import deepcopy
- from dataclasses import dataclass
+ from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple

from ax.core.data import Data
@@ -56,11 +56,11 @@ class ModelSpec(SortableBase, SerializationMixin):
model_enum: ModelRegistryBase
# Kwargs to pass into the `Model` + `ModelBridge` constructors in
# `ModelRegistryBase.__call__`.
- model_kwargs: Optional[Dict[str, Any]] = None
+ model_kwargs: Dict[str, Any] = field(default_factory=dict)
# Kwargs to pass to `ModelBridge.gen`.
- model_gen_kwargs: Optional[Dict[str, Any]] = None
+ model_gen_kwargs: Dict[str, Any] = field(default_factory=dict)
# Kwargs to pass to `cross_validate`.
- model_cv_kwargs: Optional[Dict[str, Any]] = None
+ model_cv_kwargs: Dict[str, Any] = field(default_factory=dict)

# Fitted model, constructed using specified `model_kwargs` and `Data`
# on `ModelSpec.fit`
@@ -91,19 +91,13 @@ def fixed_features(self) -> Optional[ObservationFeatures]:
"""
Fixed generation features to pass into the Model's `.gen` function.
"""
- return (
- self.model_gen_kwargs.get("fixed_features")
- if self.model_gen_kwargs is not None
- else None
- )
+ return self.model_gen_kwargs.get("fixed_features", None)

@fixed_features.setter
def fixed_features(self, value: Optional[ObservationFeatures]) -> None:
"""
Fixed generation features to pass into the Model's `.gen` function.
"""
- if self.model_gen_kwargs is None:
- self.model_gen_kwargs = {}
self.model_gen_kwargs["fixed_features"] = value

@property
@@ -131,7 +125,7 @@ def fit(
# NOTE: It's important to copy `self.model_kwargs` here to avoid actually
# adding contents of `model_kwargs` passed to this method, to
# `self.model_kwargs`.
- combined_model_kwargs = {**(self.model_kwargs or {}), **model_kwargs}
+ combined_model_kwargs = {**self.model_kwargs, **model_kwargs}
if self._fitted_model is not None and self._safe_to_update(
experiment=experiment, combined_model_kwargs=combined_model_kwargs
):
@@ -168,11 +162,12 @@ def cross_validate(
self._assert_fitted()
try:
self._cv_results = cross_validate(
- model=self.fitted_model,
- **(self.model_cv_kwargs or {}),
+ model=self.fitted_model, **self.model_cv_kwargs
)
except NotImplementedError:
- warnings.warn(f"{self.model_enum.value} cannot be cross validated")
+ warnings.warn(
+ f"{self.model_enum.value} cannot be cross validated", stacklevel=2
+ )
return None, None

self._diagnostics = compute_diagnostics(self._cv_results)
@@ -213,15 +208,10 @@ def gen(self, **model_gen_kwargs: Any) -> GeneratorRun:
"""
fitted_model = self.fitted_model
model_gen_kwargs = consolidate_kwargs(
- kwargs_iterable=[
- self.model_gen_kwargs,
- model_gen_kwargs,
- ],
+ kwargs_iterable=[self.model_gen_kwargs, model_gen_kwargs],
keywords=get_function_argument_names(fitted_model.gen),
)
- generator_run = fitted_model.gen(
- **model_gen_kwargs,
- )
+ generator_run = fitted_model.gen(**model_gen_kwargs)
fit_and_std_quality_and_generalization_dict = (
get_fit_and_std_quality_and_generalization_dict(
fitted_model_bridge=self.fitted_model,
@@ -320,6 +310,7 @@ class FactoryFunctionModelSpec(ModelSpec):
model_enum: Optional[ModelRegistryBase] = None

def __post_init__(self) -> None:
+ super().__post_init__()
if self.model_enum is not None:
raise UserInputError(
"Use regular `ModelSpec` when it's possible to describe the "
@@ -333,7 +324,8 @@ def __post_init__(self) -> None:
)
warnings.warn(
"Using a factory function to describe the model, so optimization state "
"cannot be stored and optimization is not resumable if interrupted."
"cannot be stored and optimization is not resumable if interrupted.",
stacklevel=3,
)

@property
@@ -360,7 +352,7 @@ def fit(
kwargs to this function (local kwargs take precedent)
"""
factory_function = not_none(self.factory_function)
- all_kwargs = deepcopy((self.model_kwargs or {}))
+ all_kwargs = deepcopy(self.model_kwargs)
all_kwargs.update(model_kwargs)
self._fitted_model = factory_function(
# Factory functions do not have a unified signature; e.g. some factory
9 changes: 1 addition & 8 deletions ax/modelbridge/tests/test_generation_node.py
@@ -210,7 +210,6 @@ def setUp(self) -> None:
# `Union[typing.Callable[..., ModelBridge], ModelRegistryBase]`.
model_enum=self.sobol_generation_step.model,
model_kwargs=self.model_kwargs,
- model_gen_kwargs=None,
)

def test_init(self) -> None:
@@ -251,13 +250,7 @@ def test_init_factory_function(self) -> None:
generation_step = GenerationStep(model=get_sobol, num_trials=-1)
self.assertEqual(
generation_step.model_specs,
- [
- FactoryFunctionModelSpec(
- factory_function=get_sobol,
- model_kwargs=None,
- model_gen_kwargs=None,
- )
- ],
+ [FactoryFunctionModelSpec(factory_function=get_sobol)],
)

def test_properties(self) -> None:
1 change: 0 additions & 1 deletion ax/modelbridge/tests/test_model_spec.py
@@ -146,7 +146,6 @@ def test_fixed_features(self) -> None:
new_features = ObservationFeatures(parameters={"a": 1.0})
ms.fixed_features = new_features
self.assertEqual(ms.fixed_features, new_features)
- # pyre-fixme[16]: Optional type has no attribute `__getitem__`.
self.assertEqual(ms.model_gen_kwargs["fixed_features"], new_features)

def test_gen_attaches_empty_model_fit_metadata_if_fit_not_applicable(self) -> None:
Expand Down
3 changes: 1 addition & 2 deletions ax/service/tests/test_ax_client.py
@@ -2816,7 +2816,6 @@ def test_torch_device(self) -> None:
)
ax_client = get_branin_optimization(torch_device=device)
gpei_step_kwargs = ax_client.generation_strategy._steps[1].model_kwargs
- # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
self.assertEqual(gpei_step_kwargs["torch_device"], device)

def test_repr_function(
@@ -3019,5 +3018,5 @@ def _attach_not_completed_trials(ax_client) -> None:
# Test metric evaluation method
# pyre-fixme[2]: Parameter must be annotated.
def _evaluate_test_metrics(parameters) -> Dict[str, Tuple[float, float]]:
x = np.array([parameters.get(f"x{i+1}") for i in range(2)])
x = np.array([parameters.get(f"x{i + 1}") for i in range(2)])
return {"test_metric1": (x[0] / x[1], 0.0), "test_metric2": (x[0] + x[1], 0.0)}
8 changes: 4 additions & 4 deletions ax/storage/json_store/decoder.py
@@ -720,7 +720,7 @@ def generation_step_from_json(
),
)
if kwargs
- else None
+ else {}
),
model_gen_kwargs=(
_decode_callables_from_references(
@@ -731,7 +731,7 @@
),
)
if gen_kwargs
- else None
+ else {}
),
index=generation_step_json.pop("index", -1),
should_deduplicate=generation_step_json.pop("should_deduplicate", False),
@@ -763,7 +763,7 @@ def model_spec_from_json(
),
)
if kwargs
- else None
+ else {}
),
model_gen_kwargs=(
_decode_callables_from_references(
@@ -774,7 +774,7 @@
),
)
if gen_kwargs
- else None
+ else {}
),
)

2 changes: 1 addition & 1 deletion ax/utils/testing/modeling_stubs.py
@@ -196,7 +196,7 @@ def get_generation_strategy(
search_space=get_search_space(), should_deduplicate=True
)
if with_callable_model_kwarg:
- # pyre-ignore[16]: testing hack to test serialization of callable kwargs
+ # Testing hack to test serialization of callable kwargs
# in generation steps.
gs._steps[0].model_kwargs["model_constructor"] = get_sobol
if with_experiment:
