Skip to content

Commit ca76230

Browse files
committed
pass value_scale to the LearnerND's loss_per_simplex function
1 parent d9fc5dd commit ca76230

File tree

1 file changed

+78
-16
lines changed

1 file changed

+78
-16
lines changed

adaptive/learner/learnerND.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,78 @@ def orientation(simplex):
4747
return sign
4848

4949

50-
def uniform_loss(simplex, ys=None):
50+
def uniform_loss(simplex, values, value_scale):
51+
"""
52+
Uniform loss.
53+
54+
Parameters
55+
----------
56+
simplex : list of tuples
57+
Each entry is one point of the simplex.
58+
values : list of values
59+
The scaled function values of each of the simplex points.
60+
value_scale : float
61+
The scale of values, where ``values = function_values * value_scale``.
62+
63+
Returns
64+
-------
65+
loss : float
66+
"""
5167
return volume(simplex)
5268

5369

54-
def std_loss(simplex, ys):
55-
r = np.linalg.norm(np.std(ys, axis=0))
70+
def std_loss(simplex, values, value_scale):
71+
"""
72+
Computes the loss of the simplex based on the standard deviation.
73+
74+
Parameters
75+
----------
76+
simplex : list of tuples
77+
Each entry is one point of the simplex.
78+
values : list of values
79+
The scaled function values of each of the simplex points.
80+
value_scale : float
81+
The scale of values, where ``values = function_values * value_scale``.
82+
83+
Returns
84+
-------
85+
loss : float
86+
"""
87+
88+
r = np.linalg.norm(np.std(values, axis=0))
5689
vol = volume(simplex)
5790

5891
dim = len(simplex) - 1
5992

6093
return r.flat * np.power(vol, 1.0 / dim) + vol
6194

6295

63-
def default_loss(simplex, ys):
64-
# return std_loss(simplex, ys)
65-
if isinstance(ys[0], Iterable):
66-
pts = [(*x, *y) for x, y in zip(simplex, ys)]
96+
def default_loss(simplex, values, value_scale):
97+
"""
98+
Computes the average of the volumes of the simplex.
99+
100+
Parameters
101+
----------
102+
simplex : list of tuples
103+
Each entry is one point of the simplex.
104+
values : list of values
105+
The scaled function values of each of the simplex points.
106+
value_scale : float
107+
The scale of values, where ``values = function_values * value_scale``.
108+
109+
Returns
110+
-------
111+
loss : float
112+
"""
113+
if isinstance(values[0], Iterable):
114+
pts = [(*x, *y) for x, y in zip(simplex, values)]
67115
else:
68-
pts = [(*x, y) for x, y in zip(simplex, ys)]
116+
pts = [(*x, y) for x, y in zip(simplex, values)]
69117
return simplex_volume_in_embedding(pts)
70118

71119

72120
@uses_nth_neighbors(1)
73-
def triangle_loss(simplex, values, neighbors, neighbor_values):
121+
def triangle_loss(simplex, values, value_scale, neighbors, neighbor_values):
74122
"""
75123
Computes the average of the volumes of the simplex combined with each
76124
neighbouring point.
@@ -80,7 +128,9 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
80128
simplex : list of tuples
81129
Each entry is one point of the simplex.
82130
values : list of values
83-
The function values of each of the simplex points.
131+
The scaled function values of each of the simplex points.
132+
value_scale : float
133+
The scale of values, where ``values = function_values * value_scale``.
84134
neighbors : list of tuples
85135
The neighboring points of the simplex, ordered such that simplex[0]
86136
exacly opposes neighbors[0], etc.
@@ -108,20 +158,22 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
108158
def curvature_loss_function(exploration=0.05):
109159
# XXX: add doc-string!
110160
@uses_nth_neighbors(1)
111-
def curvature_loss(simplex, values, neighbors, neighbor_values):
161+
def curvature_loss(simplex, values, value_scale, neighbors, neighbor_values):
112162
"""Compute the curvature loss of a simplex.
113163
114164
Parameters
115165
----------
116166
simplex : list of tuples
117167
Each entry is one point of the simplex.
118168
values : list of values
119-
The function values of each of the simplex points.
169+
The scaled function values of each of the simplex points.
170+
value_scale : float
171+
The scale of values, where ``values = function_values * value_scale``.
120172
neighbors : list of tuples
121173
The neighboring points of the simplex, ordered such that simplex[0]
122174
exacly opposes neighbors[0], etc.
123175
neighbor_values : list of values
124-
The function values for each of the neighboring points.
176+
The scaled function values for each of the neighboring points.
125177
126178
Returns
127179
-------
@@ -130,7 +182,9 @@ def curvature_loss(simplex, values, neighbors, neighbor_values):
130182
dim = len(simplex[0]) # the number of coordinates
131183
loss_input_volume = volume(simplex)
132184

133-
loss_curvature = triangle_loss(simplex, values, neighbors, neighbor_values)
185+
loss_curvature = triangle_loss(
186+
simplex, values, value_scale, neighbors, neighbor_values
187+
)
134188
return (
135189
loss_curvature + exploration * loss_input_volume ** ((2 + dim) / dim)
136190
) ** (1 / (2 + dim))
@@ -563,7 +617,9 @@ def _compute_loss(self, simplex):
563617

564618
if self.nth_neighbors == 0:
565619
# compute the loss on the scaled simplex
566-
return float(self.loss_per_simplex(vertices, values))
620+
return float(
621+
self.loss_per_simplex(vertices, values, self._output_multiplier)
622+
)
567623

568624
# We do need the neighbors
569625
neighbors = self.tri.get_opposing_vertices(simplex)
@@ -580,7 +636,13 @@ def _compute_loss(self, simplex):
580636
neighbor_values[i] = self._output_multiplier * value
581637

582638
return float(
583-
self.loss_per_simplex(vertices, values, neighbor_points, neighbor_values)
639+
self.loss_per_simplex(
640+
vertices,
641+
values,
642+
self._output_multiplier,
643+
neighbor_points,
644+
neighbor_values,
645+
)
584646
)
585647

586648
def _update_losses(self, to_delete: set, to_add: set):

0 commit comments

Comments
 (0)