-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathConstraintsOp.py
160 lines (130 loc) · 7.12 KB
/
ConstraintsOp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Operator enforcing constraints by variable transformations."""
from collections.abc import Sequence
import torch
import torch.nn.functional as F # noqa: N812
from mrpro.operators.EndomorphOperator import EndomorphOperator, endomorph
class ConstraintsOp(EndomorphOperator):
    """Transformation to map real-valued tensors to certain ranges.

    Each input tensor is paired (by position) with a ``(lower, upper)`` bound and mapped
    element-wise:

    - both bounds finite: scaled and shifted sigmoid, output in ``(lower, upper)``
    - only one finite bound: shifted (and possibly mirrored) softplus
    - no finite bound: identity
    """

    def __init__(
        self,
        bounds: Sequence[tuple[float | None, float | None]],
        beta_sigmoid: float = 1.0,
        beta_softplus: float = 1.0,
    ) -> None:
        """Initialize a constraint operator.

        The operator maps real-valued tensors to certain ranges. The transformation is applied element-wise.
        The transformation is defined by the bounds. The bounds are applied in the order of the input tensors.
        If there are more input tensors than bounds, the remaining tensors are passed through without transformation.

        If an input tensor is bounded from below AND above, a sigmoid transformation is applied.
        If an input tensor is bounded from below OR above, a softplus transformation is applied.

        Parameters
        ----------
        bounds
            Sequence of (lower_bound, upper_bound) values. If a bound is None, the value is not constrained.
            If a lower bound is -inf, the value is not constrained from below. If an upper bound is inf,
            the value is not constrained from above.
            If the bounds are set to (None, None) or (-inf, inf), the value is not constrained at all.
        beta_sigmoid
            beta parameter for the sigmoid transformation (used if an input has two bounds).
            A higher value leads to a steeper sigmoid.
        beta_softplus
            parameter for the softplus transformation (used if an input is either bounded from below or above).
            A higher value leads to a steeper softplus.

        Raises
        ------
        ValueError
            If ``beta_sigmoid`` or ``beta_softplus`` is not greater than zero, if any bound is nan,
            if a lower bound is not strictly smaller than its upper bound, or if a bound describes
            an empty range (lower bound of inf or upper bound of -inf).
        """
        super().__init__()

        if beta_sigmoid <= 0:
            raise ValueError(f'parameter beta_sigmoid must be greater than zero; given {beta_sigmoid}')
        if beta_softplus <= 0:
            raise ValueError(f'parameter beta_softplus must be greater than zero; given {beta_softplus}')

        self.beta_sigmoid = beta_sigmoid
        self.beta_softplus = beta_softplus

        # Materialize once so that one-shot iterables are not exhausted before validation.
        pairs = [(lb, ub) for lb, ub in bounds]
        self.lower_bounds = [lb for lb, _ in pairs]
        self.upper_bounds = [ub for _, ub in pairs]

        for lb, ub in pairs:
            # nan is the only value that compares unequal to itself; check each bound on its own
            # so that a pair like (nan, None) is rejected as well.
            if any(bound is not None and bound != bound for bound in (lb, ub)):
                raise ValueError(' "nan" is not a valid lower or upper bound;' f'\nbound tuple {lb, ub} is invalid')

            if lb is not None and ub is not None and lb >= ub:
                raise ValueError(
                    'bounds should be ( (a1,b1), (a2,b2), ...) with ai < bi if neither ai or bi is None;'
                    f'\nbound tuple {lb, ub} is invalid',
                )

            # With None read as "unbounded", (inf, None) and (None, -inf) describe empty ranges.
            if lb == float('inf') or ub == float('-inf'):
                raise ValueError(
                    'a lower bound of inf or an upper bound of -inf describes an empty range;'
                    f'\nbound tuple {lb, ub} is invalid',
                )

    @staticmethod
    def sigmoid(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Map x element-wise to (0, 1) using sigmoid(beta * x)."""
        return torch.sigmoid(beta * x)

    @staticmethod
    def sigmoid_inverse(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Inverse of 'sigmoid': map x from (0, 1) back to the real line via logit(x) / beta."""
        return torch.logit(x) / beta

    @staticmethod
    def softplus(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Map x element-wise to (0, infty) using log(1 + exp(beta * x)) / beta."""
        # -logsigmoid(-z) == log(1 + exp(z))
        return -(1 / beta) * F.logsigmoid(-beta * x)

    @staticmethod
    def softplus_inverse(x: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
        """Inverse of 'softplus': map x from (0, infty) back to the real line."""
        # log(exp(beta * x) - 1) / beta, rewritten as x + log(1 - exp(-beta * x)) / beta
        return x + torch.log(-torch.expm1(-beta * x)) / beta

    def _effective_bounds(self) -> list[tuple[float | None, float | None]]:
        """Return the (lower, upper) pairs with non-constraining entries (None, -inf, inf) set to None."""
        neg_inf, pos_inf = float('-inf'), float('inf')
        return [
            (None if lb is None or lb == neg_inf else lb, None if ub is None or ub == pos_inf else ub)
            for lb, ub in zip(self.lower_bounds, self.upper_bounds, strict=False)
        ]

    @endomorph
    def forward(self, *x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Transform tensors to chosen range.

        Parameters
        ----------
        x
            tensors to be transformed

        Returns
        -------
            tensors transformed to the range defined by the chosen bounds
        """
        x_constrained = []
        for item, (lb, ub) in zip(x, self._effective_bounds(), strict=False):
            if lb is not None and ub is not None:
                # case (a,b) with a<b and a,b \in R
                x_constrained.append(lb + (ub - lb) * self.sigmoid(item, beta=self.beta_sigmoid))
            elif lb is not None:
                # case (a,None); corresponds to (a, \infty)
                x_constrained.append(lb + self.softplus(item, beta=self.beta_softplus))
            elif ub is not None:
                # case (None,b); corresponds to (-\infty, b)
                x_constrained.append(ub - self.softplus(-item, beta=self.beta_softplus))
            else:
                # case (None,None); corresponds to (-\infty, \infty), i.e. no transformation
                x_constrained.append(item)

        # if there are more inputs than bounds, pass on the remaining inputs without transformation
        x_constrained.extend(x[len(x_constrained) :])
        return tuple(x_constrained)

    def inverse(self, *x_constrained: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Reverses the variable transformation.

        Parameters
        ----------
        x_constrained
            transformed tensors with values in the range defined by the bounds

        Returns
        -------
            tensors in the domain with no bounds
        """
        x = []
        for item, (lb, ub) in zip(x_constrained, self._effective_bounds(), strict=False):
            if lb is not None and ub is not None:
                # case (a,b) with a<b and a,b \in R
                x.append(self.sigmoid_inverse((item - lb) / (ub - lb), beta=self.beta_sigmoid))
            elif lb is not None:
                # case (a,None); corresponds to (a, \infty)
                x.append(self.softplus_inverse(item - lb, beta=self.beta_softplus))
            elif ub is not None:
                # case (None,b); corresponds to (-\infty, b)
                x.append(-self.softplus_inverse(-(item - ub), beta=self.beta_softplus))
            else:
                # case (None,None); corresponds to (-\infty, \infty), i.e. no transformation
                x.append(item)

        # if there are more inputs than bounds, pass on the remaining inputs without transformation
        x.extend(x_constrained[len(x) :])
        return tuple(x)