Skip to content

Commit

Permalink
Merge branch 'pytorch:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
CristianLara authored Jan 10, 2025
2 parents 9f818ef + aef25d7 commit eff8c03
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 45 deletions.
3 changes: 2 additions & 1 deletion botorch/models/likelihoods/sparse_outlier_noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,8 @@ def _optimal_rhos(self, mll: ExactMarginalLogLikelihood) -> Tensor:
mll.train() # NOTE: this changes model.train_inputs to be unnormalized.
X, Y = mll.model.train_inputs[0], mll.model.train_targets
F = mll.model(X)
L = mll.likelihood(F, X)
TX = mll.model.transform_inputs(X)
L = mll.likelihood(F, TX) # likelihood expects transformed inputs
S = L.covariance_matrix # (Kernel Matrix + Noise Matrix)

# NOTE: The following computation is mathematically equivalent to the formula
Expand Down
3 changes: 2 additions & 1 deletion botorch/models/relevance_pursuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,8 @@ def log_prior(
mll.train()
X, Y = mll.model.train_inputs[0], mll.model.train_targets
F = mll.model(X)
mll_i = cast(Tensor, mll(F, Y, X))
TX = mll.model.transform_inputs(X) if mll.model.training else X
mll_i = cast(Tensor, mll(F, Y, TX))
log_mll_trace.append(mll_i)
support_size, log_prior_i = log_prior(
model,
Expand Down
4 changes: 3 additions & 1 deletion botorch/optim/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@


_LBFGSB_MAXITER_MAXFUN_REGEX = re.compile( # regex for maxiter and maxfun messages
"TOTAL NO. of (ITERATIONS REACHED LIMIT|f AND g EVALUATIONS EXCEEDS LIMIT)"
# Note that the messages changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) "
+ "(ITERATIONS REACHED LIMIT|(f AND g|F,G) EVALUATIONS EXCEEDS LIMIT)"
)


Expand Down
2 changes: 1 addition & 1 deletion botorch/test_functions/multi_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def penicillin_vectorized(cls, X_input: Tensor) -> Tensor:
F_loss = (
V[active]
* cls.lambd
* (torch.exp(5 * ((T[active] - cls.T_o) / (cls.T_v - cls.T_o))) - 1)
* torch.special.expm1(5 * ((T[active] - cls.T_o) / (cls.T_v - cls.T_o)))
)
dV_dt = F[active] - F_loss
mu = (
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ gpytorch==1.13
linear_operator==0.5.3
torch>=2.0.1
pyro-ppl>=1.8.4
scipy
scipy<1.15
multipledispatch
8 changes: 5 additions & 3 deletions test/generation/test_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
import re
import warnings
from unittest import mock

Expand Down Expand Up @@ -225,13 +226,14 @@ def test_gen_candidates_scipy_with_fixed_features_inequality_constraints(self):
def test_gen_candidates_scipy_warns_opt_failure(self):
with warnings.catch_warnings(record=True) as ws:
self.test_gen_candidates(options={"maxls": 1})
expected_msg = (
expected_msg = re.compile(
# The message changed with scipy 1.15, hence the different matching here.
"Optimization failed within `scipy.optimize.minimize` with status 2"
" and message ABNORMAL_TERMINATION_IN_LNSRCH."
" and message ABNORMAL(|_TERMINATION_IN_LNSRCH)."
)
expected_warning_raised = any(
issubclass(w.category, OptimizationWarning)
and expected_msg in str(w.message)
and expected_msg.search(str(w.message))
for w in ws
)
self.assertTrue(expected_warning_raised)
Expand Down
22 changes: 22 additions & 0 deletions test/models/test_relevance_pursuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from __future__ import annotations

import itertools
import warnings

from functools import partial
from unittest.mock import patch

import gpytorch
import torch
Expand Down Expand Up @@ -197,6 +199,26 @@ def _test_robust_gp_end_to_end(
undetected_outliers = set(outlier_indices) - set(sparse_module.support)
self.assertEqual(len(undetected_outliers), 0)

with patch.object(
SparseOutlierNoise,
"forward",
wraps=sparse_module.forward,
) as sparse_module_fwd:
# testing that posterior inference on training set does not throw warnings,
# which means that the passed inputs are the equal to the cached ones.
with warnings.catch_warnings(record=True) as warnings_log:
map_model.posterior(X)
self.assertEqual(warnings_log, [])
# Testing that the noise module's forward receives transformed inputs
X_in_call = sparse_module_fwd.call_args.kwargs["X"]
self.assertIsInstance(X_in_call, list)
self.assertEqual(len(X_in_call), 1)
X_in_call = X_in_call[0]
X_max = X_in_call.amax(dim=0)
X_min = X_in_call.amin(dim=0)
self.assertAllClose(X_max, torch.ones_like(X_max))
self.assertAllClose(X_min, torch.zeros_like(X_min))

def test_robust_relevance_pursuit(self) -> None:
for optimizer, convex_parameterization, dtype in itertools.product(
[forward_relevance_pursuit, backward_relevance_pursuit],
Expand Down
2 changes: 1 addition & 1 deletion test/models/transforms/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test_norm_to_lognorm(self):
mu_ln_expected = torch.tensor(
[1.0, 2.0, 3.0], device=self.device, dtype=dtype
)
var_ln_expected = (torch.exp(var) - 1) * mu_ln_expected**2
var_ln_expected = torch.special.expm1(var) * mu_ln_expected**2
self.assertAllClose(mu_ln, mu_ln_expected)
self.assertAllClose(var_ln, var_ln_expected)

Expand Down
17 changes: 13 additions & 4 deletions test/optim/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,20 @@ def _callback(parameters, result, out) -> None:
def test_post_processing(self):
closure = next(iter(self.closures.values()))
wrapper = NdarrayOptimizationClosure(closure, closure.parameters)

# Scipy changed return values and messages in v1.15, so we check both
# old and new versions here.
status_msgs = [
# scipy >=1.15
(OptimizationStatus.FAILURE, "ABNORMAL_TERMINATION_IN_LNSRCH"),
(OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"),
# scipy <1.15
(OptimizationStatus.FAILURE, "ABNORMAL "),
(OptimizationStatus.STOPPED, "TOTAL NO. OF ITERATIONS REACHED LIMIT"),
]

with patch.object(core, "minimize_with_timeout") as mock_minimize_with_timeout:
for status, msg in (
(OptimizationStatus.FAILURE, b"ABNORMAL_TERMINATION_IN_LNSRCH"),
(OptimizationStatus.STOPPED, "TOTAL NO. of ITERATIONS REACHED LIMIT"),
):
for status, msg in status_msgs:
mock_minimize_with_timeout.return_value = OptimizeResult(
x=wrapper.state,
fun=1.0,
Expand Down
13 changes: 11 additions & 2 deletions test/optim/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
import re
from unittest.mock import MagicMock, patch
from warnings import catch_warnings

Expand All @@ -20,6 +21,11 @@
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from scipy.optimize import OptimizeResult

MAX_ITER_MSG_REGEX = re.compile(
# Note that the message changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
)


class TestFitGPyTorchMLLScipy(BotorchTestCase):
def setUp(self, suppress_input_warnings: bool = True) -> None:
Expand Down Expand Up @@ -63,15 +69,18 @@ def _test_fit_gpytorch_mll_scipy(self, mll):
)

# Test maxiter warning message
self.assertTrue(any("TOTAL NO. of" in str(w.message) for w in ws))

self.assertTrue(any(MAX_ITER_MSG_REGEX.search(str(w.message)) for w in ws))
self.assertTrue(
any(issubclass(w.category, OptimizationWarning) for w in ws)
)

# Test iteration tracking
self.assertIsInstance(result, OptimizationResult)
self.assertLessEqual(result.step, options["maxiter"])
self.assertEqual(sum(1 for w in ws if "TOTAL NO. of" in str(w.message)), 1)
self.assertEqual(
sum(1 for w in ws if MAX_ITER_MSG_REGEX.search(str(w.message))), 1
)

# Test that user provided bounds are respected
with self.subTest("bounds"), module_rollback_ctx(mll, checkpoint=ckpt):
Expand Down
55 changes: 31 additions & 24 deletions test/optim/test_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import itertools
import re
import warnings
from functools import partial
from itertools import product
Expand Down Expand Up @@ -724,19 +725,20 @@ def test_optimize_acqf_warns_on_opt_failure(self):
raw_samples=raw_samples,
batch_initial_conditions=initial_conditions,
)
message = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message "
"ABNORMAL_TERMINATION_IN_LNSRCH.')]\nBecause you specified "
"`batch_initial_conditions` larger than required `num_restarts`, "
"optimization will not be retried with new initial conditions and "
"will proceed with the current solution. Suggested remediation: "
"Try again with different `batch_initial_conditions`, don't provide "
"`batch_initial_conditions`, or increase `num_restarts`."
message_regex = re.compile(
r"Optimization failed in `gen_candidates_scipy` with the following "
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
r"`scipy.optimize.minimize` with status 2 and message "
r"ABNORMAL(: |_TERMINATION_IN_LNSRCH).'\)]\nBecause you specified "
r"`batch_initial_conditions` larger than required `num_restarts`, "
r"optimization will not be retried with new initial conditions and "
r"will proceed with the current solution. Suggested remediation: "
r"Try again with different `batch_initial_conditions`, don't provide "
r"`batch_initial_conditions`, or increase `num_restarts`."
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_regex.search(str(w.message))
for w in ws
)
self.assertTrue(expected_warning_raised)
Expand Down Expand Up @@ -774,14 +776,16 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
# more likely
options={"maxls": 2},
)
message = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
message_regex = re.compile(
r"Optimization failed in `gen_candidates_scipy` with the following "
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |"
r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of "
r"initial conditions."
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_regex.search(str(w.message))
for w in ws
)
self.assertTrue(expected_warning_raised)
Expand All @@ -803,7 +807,8 @@ def test_optimize_acqf_successfully_restarts_on_opt_failure(self):
retry_on_optimization_warning=False,
)
expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_regex.search(str(w.message))
for w in ws
)
self.assertFalse(expected_warning_raised)
Expand Down Expand Up @@ -840,19 +845,21 @@ def test_optimize_acqf_warns_on_second_opt_failure(self):
options={"maxls": 2},
)

message_1 = (
"Optimization failed in `gen_candidates_scipy` with the following "
"warning(s):\n[OptimizationWarning('Optimization failed within "
"`scipy.optimize.minimize` with status 2 and message ABNORMAL_TERMINATION"
"_IN_LNSRCH.')]\nTrying again with a new set of initial conditions."
message_1_regex = re.compile(
r"Optimization failed in `gen_candidates_scipy` with the following "
r"warning\(s\):\n\[OptimizationWarning\('Optimization failed within "
r"`scipy.optimize.minimize` with status 2 and message ABNORMAL(: |"
r"_TERMINATION_IN_LNSRCH).'\)\]\nTrying again with a new set of "
r"initial conditions."
)

message_2 = (
"Optimization failed on the second try, after generating a new set "
"of initial conditions."
)
first_expected_warning_raised = any(
issubclass(w.category, RuntimeWarning) and message_1 in str(w.message)
issubclass(w.category, RuntimeWarning)
and message_1_regex.search(str(w.message))
for w in ws
)
second_expected_warning_raised = any(
Expand Down
17 changes: 14 additions & 3 deletions test/test_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import math
import re
from collections.abc import Callable, Iterable
from contextlib import ExitStack, nullcontext
from copy import deepcopy
Expand All @@ -30,7 +31,10 @@
from gpytorch.mlls import ExactMarginalLogLikelihood, VariationalELBO
from linear_operator.utils.errors import NotPSDError

MAX_ITER_MSG = "TOTAL NO. of ITERATIONS REACHED LIMIT"
MAX_ITER_MSG_REGEX = re.compile(
# Note that the message changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
)


class MockOptimizer:
Expand Down Expand Up @@ -215,7 +219,12 @@ def _test_warnings(self, mll, ckpt):
optimizer = MockOptimizer(randomize_requires_grad=False)
optimizer.warnings = [
WarningMessage("test_runtime_warning", RuntimeWarning, __file__, 0),
WarningMessage(MAX_ITER_MSG, OptimizationWarning, __file__, 0),
WarningMessage(
"STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT",
OptimizationWarning,
__file__,
0,
),
WarningMessage(
"Optimization timed out after X", OptimizationWarning, __file__, 0
),
Expand Down Expand Up @@ -260,7 +269,9 @@ def _test_warnings(self, mll, ckpt):
{str(w.message) for w in rethrown + unresolved},
)
if logs: # test that default filter logs certain warnings
self.assertTrue(any(MAX_ITER_MSG in log for log in logs.output))
self.assertTrue(
any(MAX_ITER_MSG_REGEX.search(log) for log in logs.output)
)

# Test default of retrying upon encountering an uncaught OptimizationWarning
optimizer.warnings.append(
Expand Down
11 changes: 8 additions & 3 deletions test/test_utils/test_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.


import re
import warnings
from unittest.mock import patch

Expand All @@ -26,6 +27,12 @@
from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction


MAX_ITER_MSG = re.compile(
# Note that the message changed with scipy 1.15, hence the different matching here.
"TOTAL NO. (of|OF) ITERATIONS REACHED LIMIT"
)


class SinAcqusitionFunction(MockAcquisitionFunction):
"""Simple acquisition function with known numerical properties."""

Expand Down Expand Up @@ -56,9 +63,7 @@ def closure():

with mock_optimize_context_manager():
result = scipy_minimize(closure=closure, parameters={"x": x})
self.assertEqual(
result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT"
)
self.assertTrue(MAX_ITER_MSG.search(result.message))

with self.subTest("optimize_acqf"):
with mock_optimize_context_manager():
Expand Down

0 comments on commit eff8c03

Please sign in to comment.