Skip to content

Commit

Permalink
Add constrained synthetic test functions (#1832)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1832

Adds the following engineering-problem test functions:
- PressureVesselDesign
- WeldedBeam
- SpeedReducer
- TensionCompressionString

Reviewed By: SebastianAment

Differential Revision: D45821102

fbshipit-source-id: 22f20ffd620975d47c45fd02d58d0de8ba867f20
  • Loading branch information
Balandat authored and facebook-github-bot committed May 18, 2023
1 parent c796611 commit 323eade
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 31 deletions.
8 changes: 8 additions & 0 deletions botorch/test_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,17 @@
Levy,
Michalewicz,
Powell,
PressureVessel,
Rastrigin,
Rosenbrock,
Shekel,
SixHumpCamel,
SpeedReducer,
StyblinskiTang,
SyntheticTestFunction,
TensionCompressionString,
ThreeHumpCamel,
WeldedBeamSO,
)


Expand Down Expand Up @@ -99,17 +103,21 @@
"OSY",
"Penicillin",
"Powell",
"PressureVessel",
"Rastrigin",
"Rosenbrock",
"Shekel",
"SixHumpCamel",
"SpeedReducer",
"SRN",
"StyblinskiTang",
"SyntheticTestFunction",
"TensionCompressionString",
"ThreeHumpCamel",
"ToyRobust",
"VehicleSafety",
"WeldedBeam",
"WeldedBeamSO",
"ZDT1",
"ZDT2",
"ZDT3",
Expand Down
70 changes: 41 additions & 29 deletions botorch/test_functions/multi_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,7 +1444,10 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:

class WeldedBeam(MultiObjectiveTestProblem, ConstrainedBaseTestProblem):
r"""
The Welded Beam test problem.
The Welded Beam multi-objective test problem. Similar to `WeldedBeamSO` in
`botorch.test_function.synthetic`, but with an additional output, somewhat
modified constraints, and a different domain.
Implementation from
https://github.com/msu-coinlab/pymoo/blob/master/pymoo/problems/multi/welded_beam.py
Note that this implementation assumes minimization, so please choose negate=True.
Expand All @@ -1462,35 +1465,44 @@ class WeldedBeam(MultiObjectiveTestProblem, ConstrainedBaseTestProblem):
_ref_point = [40, 0.015]

def evaluate_true(self, X: Tensor) -> Tensor:
f1 = 1.10471 * X[..., 0] ** 2 * X[..., 1] + 0.04811 * X[..., 2] * X[..., 3] * (
14.0 + X[..., 1]
)
f2 = 2.1952 / (X[..., 3] * X[..., 2] ** 3)
# We could do the following, but the constraints are using somewhat
# different numbers (see below).
# f1 = WeldedBeam.evaluate_true(self, X)
x1, x2, x3, x4 = X.unbind(-1)
f1 = 1.10471 * (x1**2) * x2 + 0.04811 * x3 * x4 * (14.0 + x2)
f2 = 2.1952 / (x4 * x3**3)
return torch.stack([f1, f2], dim=-1)

def evaluate_slack_true(self, X: Tensor) -> Tensor:
P = 6000
L = 14
t_max = 13600
s_max = 30000

R = torch.sqrt(0.25 * (X[..., 1] ** 2 + (X[..., 0] + X[..., 2]) ** 2))
M = P * (L + X[..., 1] / 2)
J = (
2
* math.sqrt(0.5)
* X[..., 0]
* X[..., 1]
* (X[..., 1] ** 2 / 12 + 0.25 * (X[..., 0] + X[..., 2]) ** 2)
)
t1 = P / (math.sqrt(2) * X[..., 0] * X[..., 1])
x1, x2, x3, x4 = X.unbind(-1)
P = 6000.0
L = 14.0
t_max = 13600.0
s_max = 30000.0

# Ideally, we could just do the following, but the numbers in the
# single-outcome WeldedBeam are different (see below)
# g1_, g2_, g3_, _, _, g6_ = WeldedBeam.evaluate_slack_true(self, X)
# g1 = g1_ / t_max
# g2 = g2_ / s_max
# g3 = 1 / (5 - 0.125) * g3_
# g4 = 1 / P * g6_

R = torch.sqrt(0.25 * (x2**2 + (x1 + x3) ** 2))
M = P * (L + x2 / 2)
# This `J` is different than the one in [CoelloCoello2002constraint]_
# by a factor of 2 (sqrt(2) instead of sqrt(0.5))
J = 2 * math.sqrt(0.5) * x1 * x2 * (x2**2 / 12 + 0.25 * (x1 + x3) ** 2)
t1 = P / (math.sqrt(2) * x1 * x2)
t2 = M * R / J
t = torch.sqrt(t1**2 + t2**2 + t1 * t2 * X[..., 1] / R)
s = 6 * P * L / (X[..., 3] * X[..., 2] ** 2)
P_c = 64746.022 * (1 - 0.0282346 * X[..., 2]) * X[..., 2] * X[..., 3] ** 3

g1 = (1 / t_max) * (t - t_max)
g2 = (1 / s_max) * (s - s_max)
g3 = (1 / (5 - 0.125)) * (X[..., 0] - X[..., 3])
g4 = (1 / P) * (P - P_c)
return -torch.stack([g1, g2, g3, g4], dim=-1)
t = torch.sqrt(t1**2 + t1 * t2 * x2 / R + t2**2)
s = 6 * P * L / (x4 * x3**2)
# These numbers are also different from [CoelloCoello2002constraint]_
P_c = 64746.022 * (1 - 0.0282346 * x3) * x3 * x4**3

g1 = (t - t_max) / t_max
g2 = (s - s_max) / s_max
g3 = 1 / (5 - 0.125) * (x1 - x4)
g4 = (P - P_c) / P

return torch.stack([g1, g2, g3, g4], dim=-1)
209 changes: 207 additions & 2 deletions botorch/test_functions/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,32 @@

r"""
Synthetic functions for optimization benchmarks.
Reference: https://www.sfu.ca/~ssurjano/optimization.html
Most test functions (if not indicated otherwise) are taken from
[Bingham2013virtual]_.
References:
.. [Bingham2013virtual]
D. Bingham, S. Surjanovic. Virtual Library of Simulation Experiments.
https://www.sfu.ca/~ssurjano/optimization.html
.. [CoelloCoello2002constraint]
C. A. Coello Coello and E. Mezura Montes. Constraint-handling in genetic
algorithms through the use of dominance-based tournament selection.
Advanced Engineering Informatics, 16(3):193–203, 2002.
.. [Hedar2006derivfree]
A.-R. Hedar and M. Fukushima. Derivative-free filter simulated annealing
method for constrained continuous global optimization. Journal of Global
Optimization, 35(4):521–549, 2006.
.. [Lemonge2010constrained]
A. C. C. Lemonge, H. J. C. Barbosa, C. C. H. Borges, and F. B. dos Santos
Silva. Constrained optimization problems in mechanical engineering design
using a real-coded steady-state genetic algorithm. Mecánica Computacional,
XXIX:9287–9303, 2010.
"""

from __future__ import annotations
Expand All @@ -15,7 +40,8 @@
from typing import List, Optional, Tuple

import torch
from botorch.test_functions.base import BaseTestProblem
from botorch.test_functions.base import BaseTestProblem, ConstrainedBaseTestProblem
from botorch.test_functions.utils import round_nearest
from torch import Tensor


Expand Down Expand Up @@ -761,3 +787,182 @@ class ThreeHumpCamel(SyntheticTestFunction):
def evaluate_true(self, X: Tensor) -> Tensor:
x1, x2 = X[..., 0], X[..., 1]
return 2.0 * x1**2 - 1.05 * x1**4 + x1**6 / 6.0 + x1 * x2 + x2**2


# ------------ Constrained synthetic test functions ----------- #


class PressureVessel(SyntheticTestFunction, ConstrainedBaseTestProblem):
r"""Pressure vessel design problem with constraints.
The four-dimensional pressure vessel design problem with four black-box
constraints from [CoelloCoello2002constraint]_.
"""

dim = 4
num_constraints = 4
_bounds = [(0.0, 10.0), (0.0, 10.0), (10.0, 50.0), (150.0, 200.0)]

def evaluate_true(self, X: Tensor) -> Tensor:
x1, x2, x3, x4 = X.unbind(-1)
x1 = round_nearest(x1, increment=0.0625, bounds=self._bounds[0])
x2 = round_nearest(x2, increment=0.0625, bounds=self._bounds[1])
return (
0.6224 * x1 * x3 * x4
+ 1.7781 * x2 * (x3**2)
+ 3.1661 * (x1**2) * x4
+ 19.84 * (x1**2) * x3
)

def evaluate_slack_true(self, X: Tensor) -> Tensor:
x1, x2, x3, x4 = X.unbind(-1)
return -torch.stack(
[
-x1 + 0.0193 * x3,
-x2 + 0.00954 * x3,
-math.pi * (x3**2) * x4 - (4 / 3) * math.pi * (x3**3) + 1296000.0,
x4 - 240.0,
],
dim=-1,
)


class WeldedBeamSO(SyntheticTestFunction, ConstrainedBaseTestProblem):
r"""Welded beam design problem with constraints (single-outcome).
The four-dimensional welded beam design proble problem with six
black-box constraints from [CoelloCoello2002constraint]_.
For a (somewhat modified) multi-objective version, see
`botorch.test_functions.multi_objective.WeldedBeam`.
"""

dim = 4
num_constraints = 6
_bounds = [(0.125, 10.0), (0.1, 10.0), (0.1, 10.0), (0.1, 10.0)]

def evaluate_true(self, X: Tensor) -> Tensor:
x1, x2, x3, x4 = X.unbind(-1)
return 1.10471 * (x1**2) * x2 + 0.04811 * x3 * x4 * (14.0 + x2)

def evaluate_slack_true(self, X: Tensor) -> Tensor:
x1, x2, x3, x4 = X.unbind(-1)
P = 6000.0
L = 14.0
E = 30e6
G = 12e6
t_max = 13600.0
s_max = 30000.0
d_max = 0.25

M = P * (L + x2 / 2)
R = torch.sqrt(0.25 * (x2**2 + (x1 + x3) ** 2))
J = 2 * math.sqrt(2) * x1 * x2 * (x2**2 / 12 + 0.25 * (x1 + x3) ** 2)
P_c = (
4.013
* E
* x3
* (x4**3)
* 6
/ (L**2)
* (1 - 0.25 * x3 * math.sqrt(E / G) / L)
)
t1 = P / (math.sqrt(2) * x1 * x2)
t2 = M * R / J
t = torch.sqrt(t1**2 + t1 * t2 * x2 / R + t2**2)
s = 6 * P * L / (x4 * x3**2)
d = 4 * P * L**3 / (E * x3**3 * x4)

g1 = t - t_max
g2 = s - s_max
g3 = x1 - x4
g4 = 0.10471 * x1**2 + 0.04811 * x3 * x4 * (14.0 + x2) - 5.0
g5 = d - d_max
g6 = P - P_c

return -torch.stack([g1, g2, g3, g4, g5, g6], dim=-1)


class TensionCompressionString(SyntheticTestFunction, ConstrainedBaseTestProblem):
r"""Tension compression string optimization problem with constraints.
The three-dimensional tension compression string optimization problem with
four black-box constraints from [Hedar2006derivfree]_.
"""

dim = 3
num_constraints = 4
_bounds = [(0.01, 1.0), (0.01, 1.0), (0.01, 20.0)]

def evaluate_true(self, X: Tensor) -> Tensor:
x1, x2, x3 = X.unbind(-1)
return (x1**2) * x2 * (x3 + 2)

def evaluate_slack_true(self, X: Tensor) -> Tensor:
x1, x2, x3 = X.unbind(-1)
constraints = torch.stack(
[
1 - (x2**3) * x3 / (71785 * (x1**4)),
(4 * (x2**2) - x1 * x2) / (12566 * (x1**3) * (x2 - x1))
+ 1 / (5108 * (x1**2))
- 1,
1 - 140.45 * x1 / (x3 * (x2**2)),
(x1 + x2) / 1.5 - 1,
],
dim=-1,
)
return -constraints.clamp_max(100)


class SpeedReducer(SyntheticTestFunction, ConstrainedBaseTestProblem):
r"""Speed Reducer design problem with constraints.
The seven-dimensional speed reducer design problem with eleven black-box
constraints from [Lemonge2010constrained]_.
"""

dim = 7
num_constraints = 11
_bounds = [
(2.6, 3.6),
(0.7, 0.8),
(17.0, 28.0),
(7.3, 8.3),
(7.8, 8.3),
(2.9, 3.9),
(5.0, 5.5),
]

def evaluate_true(self, X: Tensor) -> Tensor:
x1, x2, x3, x4, x5, x6, x7 = X.unbind(-1)
return (
0.7854 * x1 * (x2**2) * (3.3333 * (x3**2) + 14.9334 * x3 - 43.0934)
+ -1.508 * x1 * (x6**2 + x7**2)
+ 7.4777 * (x6**3 + x7**3)
+ 0.7854 * (x4 * (x6**2) + x5 * (x7**2))
)

def evaluate_slack_true(self, X: Tensor) -> Tensor:
x1, x2, x3, x4, x5, x6, x7 = X.unbind(-1)
return -torch.stack(
[
27.0 * (1 / x1) * (1 / (x2**2)) * (1 / x3) - 1,
397.5 * (1 / x1) * (1 / (x2**2)) * (1 / (x3**2)) - 1,
1.93 * (1 / x2) * (1 / x3) * (x4**3) * (1 / (x6**4)) - 1,
1.93 * (1 / x2) * (1 / x3) * (x5**3) * (1 / (x7**4)) - 1,
1
/ (0.1 * (x6**3))
* torch.sqrt((745 * x4 / (x2 * x3)) ** 2 + 16.9 * 1e6)
- 1100,
1
/ (0.1 * (x7**3))
* torch.sqrt((745 * x5 / (x2 * x3)) ** 2 + 157.5 * 1e6)
- 850,
x2 * x3 - 40,
5 - x1 / x2,
x1 / x2 - 12,
(1.5 * x6 + 1.9) / x4 - 1,
(1.1 * x7 + 1.9) / x5 - 1,
],
dim=-1,
)
36 changes: 36 additions & 0 deletions botorch/test_functions/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from __future__ import annotations

from typing import Optional, Tuple

import torch

from torch import Tensor


def round_nearest(
X: Tensor, increment: float, bounds: Optional[Tuple[float, float]]
) -> Tensor:
r"""Rounds the input tensor to the nearest multiple of `increment`.
Args:
X: The input to be rounded.
increment: The increment to round to.
bounds: An optional tuple of two floats representing the lower and upper
bounds on `X`. If provided, this will round to the nearest multiple
of `increment` that lies within the bounds.
Returns:
The rounded input.
"""
X_round = torch.round(X / increment) * increment
if bounds is not None:
X_round = torch.where(X_round < bounds[0], X_round + increment, X_round)
X_round = torch.where(X_round > bounds[1], X_round - increment, X_round)
return X_round
5 changes: 5 additions & 0 deletions sphinx/source/test_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,8 @@ Sensitivity Analysis Test Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.test_functions.sensitivity_analysis
:members:

Utilities For Test Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.test_functions.utils
:members:
Loading

0 comments on commit 323eade

Please sign in to comment.