Skip to content

Commit

Permalink
Fix warning filters in TestCase, introduce AxParameterWarning (#2389)
Browse files Browse the repository at this point in the history
Summary:
These filters were not filtering out warning as expected, due to some regex issues and unnecessary `warnings.resetwarnings()` within unit tests. `warnings.catch_warnings` resets the filters by default, so such calls were unnecessary.

Note: The BoTorch dtype warnings are produced both as `InputDataWarning` & `UserWarning`. This is only filtering for `InputDataWarning`. We can reduce the warnings further by updating it on BoTorch to use `InputDataWarning`.

Pull Request resolved: #2389

Test Plan: Ran some tests locally and checked what warnings are produced. `pytest -ra ax/core` now produces ~250 warnings, down from ~700. `test_fully_bayesian` now produces 68 warnings, down from 576.

Reviewed By: mgarrard

Differential Revision: D56443894

Pulled By: saitcakmak

fbshipit-source-id: 3bd5f694e1de843288b269223156ce3b90cf3091
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 26, 2024
1 parent 5410ee2 commit 0ea8ef8
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 80 deletions.
8 changes: 5 additions & 3 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from warnings import warn

from ax.core.types import TNumeric, TParamValue, TParamValueList
from ax.exceptions.core import AxWarning, UserInputError
from ax.exceptions.core import AxParameterWarning, UserInputError
from ax.utils.common.base import SortableBase
from ax.utils.common.typeutils import not_none
from pyre_extensions import assert_is_instance
Expand Down Expand Up @@ -566,7 +566,7 @@ def __init__(
warn(
f"Duplicate values found for ChoiceParameter {name}. "
"Initializing the parameter with duplicate values removed. ",
AxWarning,
AxParameterWarning,
stacklevel=2,
)
values = list(dict_values)
Expand Down Expand Up @@ -604,7 +604,9 @@ def _get_default_bool_and_warn(self, param_string: str) -> bool:
f"Defaulting to `{default_bool}` for parameters of `ParameterType` "
f"{self.parameter_type.name}. To override this behavior (or avoid this "
f"warning), specify `{param_string}` during `ChoiceParameter` "
"construction."
"construction.",
AxParameterWarning,
stacklevel=3,
)
return default_bool

Expand Down
18 changes: 5 additions & 13 deletions ax/core/tests/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

# pyre-strict

import warnings

from ax.core.metric import Metric
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -53,20 +51,14 @@ def test_Init(self) -> None:
metrics=[self.metrics["m1"], self.metrics["m2"]],
minimize=False,
)
warnings.resetwarnings()
warnings.simplefilter("always", append=True)
with warnings.catch_warnings(record=True) as ws:
with self.assertWarnsRegex(
DeprecationWarning, "Defaulting to `minimize=False`"
):
Objective(metric=self.metrics["m1"])
self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
self.assertTrue(
any("Defaulting to `minimize=False`" in str(w.message) for w in ws)
)
with warnings.catch_warnings(record=True) as ws:
with self.assertWarnsRegex(UserWarning, "Attempting to maximize"):
Objective(Metric(name="m4", lower_is_better=True), minimize=False)
self.assertTrue(any("Attempting to maximize" in str(w.message) for w in ws))
with warnings.catch_warnings(record=True) as ws:
with self.assertWarnsRegex(UserWarning, "Attempting to minimize"):
Objective(Metric(name="m4", lower_is_better=False), minimize=True)
self.assertTrue(any("Attempting to minimize" in str(w.message) for w in ws))
self.assertEqual(
self.objective.get_unconstrainable_metrics(), [self.metrics["m1"]]
)
Expand Down
20 changes: 3 additions & 17 deletions ax/exceptions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ def __str__(self) -> str:
class UserInputError(AxError):
"""Raised when the user passes in an invalid input"""

pass


class UnsupportedError(AxError):
"""Raised when an unsupported request is made.
Expand All @@ -42,8 +40,6 @@ class UnsupportedError(AxError):
It should not be used for TODO (another common use case of NIE).
"""

pass


class UnsupportedPlotError(AxError):
"""Raised when plotting functionality is not supported for the
Expand Down Expand Up @@ -85,8 +81,6 @@ class NoDataError(AxError):
Useful to distinguish data failure reasons in automated analyses.
"""

pass


class DataRequiredError(AxError):
"""Raised when more observed data is needed by the model to continue the
Expand All @@ -96,14 +90,10 @@ class DataRequiredError(AxError):
more data is available.
"""

pass


class MisconfiguredExperiment(AxError):
"""Raised when experiment has incomplete or incorrect information."""

pass


class OptimizationComplete(AxError):
"""Raised when you hit SearchSpaceExhausted and GenerationStrategyComplete."""
Expand Down Expand Up @@ -135,14 +125,10 @@ class ObjectNotFoundError(AxError, ValueError):
may be removed in the future.
"""

pass


class ExperimentNotFoundError(ObjectNotFoundError):
"""Raised when an experiment is not found in the database."""

pass


class SearchSpaceExhausted(OptimizationComplete):
"""Raised when using an algorithm that deduplicates points and no more
Expand All @@ -158,8 +144,6 @@ def __init__(self, message: str) -> None:
class IncompatibleDependencyVersion(AxError):
"""Raise when an imcompatible dependency version is installed."""

pass


class AxWarning(Warning):
"""Base Ax warning.
Expand All @@ -181,4 +165,6 @@ def __str__(self) -> str:
class AxStorageWarning(AxWarning):
"""Ax warning used for storage related concerns."""

pass

class AxParameterWarning(AxWarning):
"""Ax warning used for concerns related to parameter setups."""
17 changes: 7 additions & 10 deletions ax/modelbridge/transforms/tests/test_winsorize_legacy_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

# pyre-strict

import warnings
from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -108,20 +107,18 @@ def setUp(self) -> None:
)

def test_PrintDeprecationWarning(self) -> None:
warnings.simplefilter("always", DeprecationWarning)
with warnings.catch_warnings(record=True) as ws:
expected_warning = (
"Winsorization received an out-of-date `transform_config`, containing "
"the following deprecated keys: {'winsorization_upper'}. Please "
"update the config according to the docs of "
"`ax.modelbridge.transforms.winsorize.Winsorize`."
)
with self.assertWarnsRegex(DeprecationWarning, expected_warning):
Winsorize(
search_space=None,
observations=deepcopy(self.observations),
config={"winsorization_upper": 0.2},
)
self.assertTrue(
"Winsorization received an out-of-date `transform_config`, containing "
"the following deprecated keys: {'winsorization_upper'}. Please "
"update the config according to the docs of "
"`ax.modelbridge.transforms.winsorize.Winsorize`."
in [str(w.message) for w in ws]
)

def test_Init(self) -> None:
self.assertEqual(self.t.cutoffs["m1"], (-float("inf"), 2.0))
Expand Down
44 changes: 20 additions & 24 deletions ax/modelbridge/transforms/tests/test_winsorize_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,17 @@ def setUp(self) -> None:
)

def test_PrintDeprecationWarning(self) -> None:
warnings.simplefilter("always", DeprecationWarning)
with warnings.catch_warnings(record=True) as ws:
expected_warning = (
"Winsorization received an out-of-date `transform_config`, containing "
'the key `"optimization_config"`. Please update the config according '
"to the docs of `ax.modelbridge.transforms.winsorize.Winsorize`."
)
with self.assertWarnsRegex(DeprecationWarning, expected_warning):
Winsorize(
search_space=None,
observations=deepcopy(self.observations),
config={"optimization_config": "dummy_val"},
)
self.assertTrue(
"Winsorization received an out-of-date `transform_config`, containing "
'the key `"optimization_config"`. Please update the config according '
"to the docs of `ax.modelbridge.transforms.winsorize.Winsorize`."
in [str(w.message) for w in ws]
)

def test_Init(self) -> None:
self.assertEqual(self.t.cutoffs["m1"], (-INF, 2.0))
Expand Down Expand Up @@ -491,37 +489,35 @@ def test_winsorization_with_optimization_config(self) -> None:
metrics=[m1, m3], op=ComparisonOp.GEQ, bound=3, relative=False
)
]
warnings.simplefilter("always", append=True)
with warnings.catch_warnings(record=True) as ws:
transform = get_transform(
observation_data=deepcopy(all_obsd),
optimization_config=optimization_config,
)
for i in range(2):
self.assertTrue(
"Automatic winsorization isn't supported for a "
"`ScalarizedOutcomeConstraint`. Specify the winsorization settings "
f"manually if you want to winsorize metric m{['1', '3'][i]}."
in [str(w.message) for w in ws]
)
for i in range(2):
self.assertTrue(
"Automatic winsorization isn't supported for a "
"`ScalarizedOutcomeConstraint`. Specify the winsorization settings "
f"manually if you want to winsorize metric m{['1', '3'][i]}."
in [str(w.message) for w in ws]
)
# Multi-objective without objective thresholds should warn and winsorize
moo_objective = MultiObjective(
[Objective(m1, minimize=False), Objective(m2, minimize=True)]
)
optimization_config = MultiObjectiveOptimizationConfig(objective=moo_objective)
warnings.simplefilter("always", append=True)
with warnings.catch_warnings(record=True) as ws:
transform = get_transform(
observation_data=deepcopy(all_obsd),
optimization_config=optimization_config,
)
for _ in range(2):
self.assertTrue(
"Encountered a `MultiObjective` without objective thresholds. We "
"will winsorize each objective separately. We strongly recommend "
"specifying the objective thresholds when using multi-objective "
"optimization." in [str(w.message) for w in ws]
)
for _ in range(2):
self.assertTrue(
"Encountered a `MultiObjective` without objective thresholds. We "
"will winsorize each objective separately. We strongly recommend "
"specifying the objective thresholds when using multi-objective "
"optimization." in [str(w.message) for w in ws]
)
self.assertEqual(transform.cutoffs["m1"], (-6.5, INF))
self.assertEqual(transform.cutoffs["m2"], (-INF, 10.0))
self.assertEqual(transform.cutoffs["m3"], (-INF, INF))
Expand Down
12 changes: 3 additions & 9 deletions ax/models/tests/test_fully_bayesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
# pyre-strict

import dataclasses
import warnings
from contextlib import ExitStack
from itertools import count, product
from logging import Logger
Expand Down Expand Up @@ -101,15 +100,10 @@ def test_FullyBayesianBotorchModel(
self, dtype: torch.dtype = torch.float, cuda: bool = False
) -> None:
# test deprecation warning
warnings.resetwarnings() # this is necessary for building in mode/opt
warnings.simplefilter("always", append=True)
with warnings.catch_warnings(record=True) as ws:
with self.assertWarnsRegex(
DeprecationWarning, "Passing `use_saas` is no longer supported"
):
self.model_cls(use_saas=True)
self.assertTrue(
any(issubclass(w.category, DeprecationWarning) for w in ws)
)
msg = "Passing `use_saas` is no longer supported"
self.assertTrue(any(msg in str(w.message) for w in ws))
Xs1, Ys1, Yvars1, bounds, tfs, fns, mns = get_torch_test_data(
dtype=dtype, cuda=cuda, constant_noise=True
)
Expand Down
9 changes: 5 additions & 4 deletions ax/utils/common/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@

import numpy as np
import yappi
from ax.exceptions.core import AxParameterWarning
from ax.utils.common.base import Base
from ax.utils.common.equality import object_attribute_dicts_find_unequal_fields
from ax.utils.common.logger import get_logger
Expand Down Expand Up @@ -334,19 +335,19 @@ def setUp(self) -> None:
# Choice parameter default parameter type / is_ordered warnings.
warnings.filterwarnings(
"ignore",
message="is not specified for `ChoiceParameter`",
category=UserWarning,
message=".*is not specified for .ChoiceParameter.*",
category=AxParameterWarning,
)
# BoTorch float32 warning.
warnings.filterwarnings(
"ignore",
message="The model inputs are of type",
category=UserWarning,
category=InputDataWarning,
)
# BoTorch input standardization warnings.
warnings.filterwarnings(
"ignore",
message="Input data is not standardized.",
message="Input data is not",
category=InputDataWarning,
)

Expand Down

0 comments on commit 0ea8ef8

Please sign in to comment.