Skip to content

Commit

Permalink
Add an option to pass model_cv_kwargs to ModelSpec.cross_validate (#2566)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: #2566

This allows us to customize the kwargs passed to `cross_validate` at call time (in addition to `ModelSpec.model_cv_kwargs`, which is specified at initialization); when both specify the same key, the call-time value takes precedence. This can be used in `SingleDiagnosticBestModelSelector` to customize how the CV diagnostics are computed.

To ensure we don't return cached results that were computed using different kwargs, we also store the last `cv_kwargs` used and re-compute CV if they changed.

Reviewed By: mgarrard

Differential Revision: D59406177

fbshipit-source-id: 577a6965aaa362d301c152ac7b94bf4fe7326541
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Jul 9, 2024
1 parent 5862f7f commit faf25c4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
36 changes: 28 additions & 8 deletions ax/modelbridge/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,15 @@ class ModelSpec(SortableBase, SerializationMixin):
# on `ModelSpec.fit`
_fitted_model: Optional[ModelBridge] = None

# stored cross validation results set in cross validate
# Stored cross validation results set in cross validate.
_cv_results: Optional[List[CVResult]] = None

# stored cross validation diagnostics set in cross validate
# Stored cross validation diagnostics set in cross validate.
_diagnostics: Optional[CVDiagnostics] = None

# Stored to check if the CV result & diagnostic cache is safe to reuse.
_last_cv_kwargs: Optional[Dict[str, Any]] = None

# Stored to check if the model can be safely updated in fit.
_last_fit_arg_ids: Optional[Dict[str, int]] = None

Expand Down Expand Up @@ -151,26 +154,43 @@ def fit(

def cross_validate(
self,
model_cv_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[List[CVResult]], Optional[CVDiagnostics]]:
"""
Call cross_validate, compute_diagnostics and cache the results
If the model cannot be cross validated, warn and return None
Call cross_validate, compute_diagnostics and cache the results.
If the model cannot be cross validated, warn and return None.
NOTE: If there are cached results, and the cache was computed using
the same kwargs, this will return the cached results.
Args:
model_cv_kwargs: Optional kwargs to pass into `cross_validate` call.
These are combined with `self.model_cv_kwargs`, with the
`model_cv_kwargs` taking precedence over `self.model_cv_kwargs`.
Returns:
A tuple of CV results (observed vs predicted values) and the
corresponding diagnostics.
"""
if self._cv_results is not None and self._diagnostics is not None:
cv_kwargs = {**self.model_cv_kwargs, **(model_cv_kwargs or {})}
if (
self._cv_results is not None
and self._diagnostics is not None
and cv_kwargs == self._last_cv_kwargs
):
return self._cv_results, self._diagnostics

self._assert_fitted()
try:
self._cv_results = cross_validate(
model=self.fitted_model, **self.model_cv_kwargs
)
self._cv_results = cross_validate(model=self.fitted_model, **cv_kwargs)
except NotImplementedError:
warnings.warn(
f"{self.model_enum.value} cannot be cross validated", stacklevel=2
)
return None, None

self._diagnostics = compute_diagnostics(self._cv_results)
self._last_cv_kwargs = cv_kwargs
return self._cv_results, self._diagnostics

@property
Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/tests/test_best_model_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def setUp(self) -> None:
ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR)
ms._cv_results = Mock()
ms._diagnostics = diagnostics
ms._last_cv_kwargs = {}
self.model_specs.append(ms)

def test_user_input_error(self) -> None:
Expand Down
14 changes: 14 additions & 0 deletions ax/modelbridge/tests/test_model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_cross_validate_with_GP_model(
self.assertIsNotNone(cv_diagnostics)
mock_cv.assert_not_called()
mock_diagnostics.assert_not_called()
self.assertEqual(ms._last_cv_kwargs, {"test_key": "test-value"})

with self.subTest("fit clears the CV cache"):
mock_cv.reset_mock()
Expand All @@ -117,6 +118,19 @@ def test_cross_validate_with_GP_model(
mock_cv.assert_called_with(model=fake_mb, test_key="test-value")
mock_diagnostics.assert_called_with(["fake-cv-result"])

with self.subTest("pass in optional kwargs"):
mock_cv.reset_mock()
mock_diagnostics.reset_mock()
# Cache is not empty, but CV will be called since there are new kwargs.
assert ms._cv_results is not None

cv_results, cv_diagnostics = ms.cross_validate(model_cv_kwargs={"test": 1})

self.assertIsNotNone(cv_results)
self.assertIsNotNone(cv_diagnostics)
mock_cv.assert_called_with(model=fake_mb, test_key="test-value", test=1)
self.assertEqual(ms._last_cv_kwargs, {"test": 1, "test_key": "test-value"})

@patch(f"{ModelSpec.__module__}.compute_diagnostics")
@patch(f"{ModelSpec.__module__}.cross_validate", side_effect=NotImplementedError)
def test_cross_validate_with_non_GP_model(
Expand Down

0 comments on commit faf25c4

Please sign in to comment.