From 373fe81b4e79a3faadf2efd6d820ac8771299c25 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Wed, 28 Aug 2024 11:38:50 -0700 Subject: [PATCH] Provide more informative warning messages in `InputDataWarning` (#2713) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2713 X-link: https://github.com/pytorch/botorch/pull/2489 Provide more informative warning messages when `InputDataWarning`s are raised to specify whether it pertains to input features or output targets. Updated unit tests accordingly to ensure coverage. Reviewed By: Balandat Differential Revision: D61797434 fbshipit-source-id: 494e5ce0e4713796f1836406d2a89c7138dba667 --- ax/modelbridge/base.py | 2 +- ax/modelbridge/cross_validation.py | 2 +- ax/modelbridge/tests/test_base_modelbridge.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 7eb50551e28..69084cff682 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -904,7 +904,7 @@ def cross_validate( # users with this warning, we filter it out. warnings.filterwarnings( "ignore", - message="Data is not standardized", + message=r"Data \(outcome observations\) not standardized", category=InputDataWarning, ) cv_predictions = self._cross_validate( diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index e853316a0ae..4293ba8203e 100644 --- a/ax/modelbridge/cross_validation.py +++ b/ax/modelbridge/cross_validation.py @@ -184,7 +184,7 @@ def cross_validate( # To avoid confusing users with this warning, we filter it out. warnings.filterwarnings( "ignore", - message="Data is not standardized", + message=r"Data \(outcome observations\) not standardized", category=InputDataWarning, ) cv_test_predictions = model._cross_validate( diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index 84eae4b30d9..f0a2ed499fd 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -225,7 +225,7 @@ def warn_and_return_mock_obs( nonlocal called called = True warnings.warn( - "Data is not standardized", + "Data (outcome observations) not standardized", InputDataWarning, stacklevel=2, )