Skip to content

Commit

Permalink
Silence input normalization warnings in cross validation (#2310)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2310

If we compute CV directly using `TorchModelBridge._cross_validate`, we skip the part of `ModelBridge.cross_validate` that silences these warnings. Replicated the same logic for CV computations in the untransformed space.

Reviewed By: sunnyshen321

Differential Revision: D55648398

fbshipit-source-id: 3a6f7a9ac29e21bcf8413a835dbd04ef3e117d7b
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 2, 2024
1 parent e956132 commit 0014c75
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 11 deletions.
21 changes: 16 additions & 5 deletions ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from copy import deepcopy
Expand Down Expand Up @@ -39,6 +40,7 @@
ModelFitMetricProtocol,
std_of_the_standardized_error,
)
from botorch.exceptions.warnings import InputDataWarning

logger: Logger = get_logger(__name__)

Expand Down Expand Up @@ -170,11 +172,20 @@ def cross_validate(
) = model._transform_inputs_for_cv(
cv_training_data=cv_training_data, cv_test_points=cv_test_points
)
cv_test_predictions = model._cross_validate(
search_space=search_space,
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
)
with warnings.catch_warnings():
# Since each CV fold removes points from the training data, the
# remaining observations will not pass the standardization test.
# To avoid confusing users with this warning, we filter it out.
warnings.filterwarnings(
"ignore",
message="Input data is not standardized.",
category=InputDataWarning,
)
cv_test_predictions = model._cross_validate(
search_space=search_space,
cv_training_data=cv_training_data,
cv_test_points=cv_test_points,
)
# Get test observations in transformed space
cv_test_data = deepcopy(cv_test_data)
for t in model.transforms.values():
Expand Down
8 changes: 2 additions & 6 deletions ax/modelbridge/tests/test_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
MultiObjectiveOptimizationConfig,
OptimizationConfig,
)

from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.parameter import FixedParameter, ParameterType
from ax.core.search_space import SearchSpace
Expand Down Expand Up @@ -205,9 +204,7 @@ def test_CrossValidate(self) -> None:

# Test selector

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def test_selector(obs):
def test_selector(obs: Observation) -> bool:
return obs.features.parameters["x"] != 4.0

result = cross_validate(model=ma, folds=-1, test_selector=test_selector)
Expand Down Expand Up @@ -269,10 +266,9 @@ def test_cross_validate_gives_a_useful_error_for_model_with_no_data(self) -> Non
with self.assertRaisesRegex(ValueError, "no training data"):
cross_validate(model=sobol)

# pyre-fixme[3]: Return type must be annotated.
def test_cross_validate_raises_not_implemented_error_for_non_cv_model_with_data(
self,
):
) -> None:
exp = get_branin_experiment(with_batch=True)
exp.trials[0].run().complete()
sobol = Models.SOBOL(
Expand Down
13 changes: 13 additions & 0 deletions ax/modelbridge/tests/test_model_fit_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

# pyre-strict

import warnings
from typing import cast, Dict

from ax.core.experiment import Experiment
Expand Down Expand Up @@ -90,3 +91,15 @@ def test_model_fit_metrics(self) -> None:
)
self.assertIsInstance(empty_metrics, dict)
self.assertTrue(len(empty_metrics) == 0)

# testing log filtering
with warnings.catch_warnings(record=True) as ws:
fit_metrics = compute_model_fit_metrics_from_modelbridge(
model_bridge=model_bridge,
experiment=self.branin_experiment,
untransform=False,
generalization=True,
)
self.assertFalse(
any("Input data is not standardized" in str(w.message) for w in ws)
)

0 comments on commit 0014c75

Please sign in to comment.