@@ -47,30 +47,78 @@ def orientation(simplex):
47
47
return sign
48
48
49
49
50
- def uniform_loss (simplex , ys = None ):
50
+ def uniform_loss (simplex , values , value_scale ):
51
+ """
52
+ Uniform loss.
53
+
54
+ Parameters
55
+ ----------
56
+ simplex : list of tuples
57
+ Each entry is one point of the simplex.
58
+ values : list of values
59
+ The scaled function values of each of the simplex points.
60
+ value_scale : float
61
+ The scale of values, where ``values = function_values * value_scale``.
62
+
63
+ Returns
64
+ -------
65
+ loss : float
66
+ """
51
67
return volume (simplex )
52
68
53
69
54
- def std_loss (simplex , ys ):
55
- r = np .linalg .norm (np .std (ys , axis = 0 ))
70
+ def std_loss (simplex , values , value_scale ):
71
+ """
72
+ Computes the loss of the simplex based on the standard deviation.
73
+
74
+ Parameters
75
+ ----------
76
+ simplex : list of tuples
77
+ Each entry is one point of the simplex.
78
+ values : list of values
79
+ The scaled function values of each of the simplex points.
80
+ value_scale : float
81
+ The scale of values, where ``values = function_values * value_scale``.
82
+
83
+ Returns
84
+ -------
85
+ loss : float
86
+ """
87
+
88
+ r = np .linalg .norm (np .std (values , axis = 0 ))
56
89
vol = volume (simplex )
57
90
58
91
dim = len (simplex ) - 1
59
92
60
93
return r .flat * np .power (vol , 1.0 / dim ) + vol
61
94
62
95
63
- def default_loss (simplex , ys ):
64
- # return std_loss(simplex, ys)
65
- if isinstance (ys [0 ], Iterable ):
66
- pts = [(* x , * y ) for x , y in zip (simplex , ys )]
96
+ def default_loss (simplex , values , value_scale ):
97
+ """
98
+ Computes the average of the volumes of the simplex.
99
+
100
+ Parameters
101
+ ----------
102
+ simplex : list of tuples
103
+ Each entry is one point of the simplex.
104
+ values : list of values
105
+ The scaled function values of each of the simplex points.
106
+ value_scale : float
107
+ The scale of values, where ``values = function_values * value_scale``.
108
+
109
+ Returns
110
+ -------
111
+ loss : float
112
+ """
113
+ if isinstance (values [0 ], Iterable ):
114
+ pts = [(* x , * y ) for x , y in zip (simplex , values )]
67
115
else :
68
- pts = [(* x , y ) for x , y in zip (simplex , ys )]
116
+ pts = [(* x , y ) for x , y in zip (simplex , values )]
69
117
return simplex_volume_in_embedding (pts )
70
118
71
119
72
120
@uses_nth_neighbors (1 )
73
- def triangle_loss (simplex , values , neighbors , neighbor_values ):
121
+ def triangle_loss (simplex , values , value_scale , neighbors , neighbor_values ):
74
122
"""
75
123
Computes the average of the volumes of the simplex combined with each
76
124
neighbouring point.
@@ -80,7 +128,9 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
80
128
simplex : list of tuples
81
129
Each entry is one point of the simplex.
82
130
values : list of values
83
- The function values of each of the simplex points.
131
+ The scaled function values of each of the simplex points.
132
+ value_scale : float
133
+ The scale of values, where ``values = function_values * value_scale``.
84
134
neighbors : list of tuples
85
135
The neighboring points of the simplex, ordered such that simplex[0]
86
136
exacly opposes neighbors[0], etc.
@@ -108,20 +158,22 @@ def triangle_loss(simplex, values, neighbors, neighbor_values):
108
158
def curvature_loss_function (exploration = 0.05 ):
109
159
# XXX: add doc-string!
110
160
@uses_nth_neighbors (1 )
111
- def curvature_loss (simplex , values , neighbors , neighbor_values ):
161
+ def curvature_loss (simplex , values , value_scale , neighbors , neighbor_values ):
112
162
"""Compute the curvature loss of a simplex.
113
163
114
164
Parameters
115
165
----------
116
166
simplex : list of tuples
117
167
Each entry is one point of the simplex.
118
168
values : list of values
119
- The function values of each of the simplex points.
169
+ The scaled function values of each of the simplex points.
170
+ value_scale : float
171
+ The scale of values, where ``values = function_values * value_scale``.
120
172
neighbors : list of tuples
121
173
The neighboring points of the simplex, ordered such that simplex[0]
122
174
exacly opposes neighbors[0], etc.
123
175
neighbor_values : list of values
124
- The function values for each of the neighboring points.
176
+ The scaled function values for each of the neighboring points.
125
177
126
178
Returns
127
179
-------
@@ -130,7 +182,9 @@ def curvature_loss(simplex, values, neighbors, neighbor_values):
130
182
dim = len (simplex [0 ]) # the number of coordinates
131
183
loss_input_volume = volume (simplex )
132
184
133
- loss_curvature = triangle_loss (simplex , values , neighbors , neighbor_values )
185
+ loss_curvature = triangle_loss (
186
+ simplex , values , value_scale , neighbors , neighbor_values
187
+ )
134
188
return (
135
189
loss_curvature + exploration * loss_input_volume ** ((2 + dim ) / dim )
136
190
) ** (1 / (2 + dim ))
@@ -563,7 +617,9 @@ def _compute_loss(self, simplex):
563
617
564
618
if self .nth_neighbors == 0 :
565
619
# compute the loss on the scaled simplex
566
- return float (self .loss_per_simplex (vertices , values ))
620
+ return float (
621
+ self .loss_per_simplex (vertices , values , self ._output_multiplier )
622
+ )
567
623
568
624
# We do need the neighbors
569
625
neighbors = self .tri .get_opposing_vertices (simplex )
@@ -580,7 +636,13 @@ def _compute_loss(self, simplex):
580
636
neighbor_values [i ] = self ._output_multiplier * value
581
637
582
638
return float (
583
- self .loss_per_simplex (vertices , values , neighbor_points , neighbor_values )
639
+ self .loss_per_simplex (
640
+ vertices ,
641
+ values ,
642
+ self ._output_multiplier ,
643
+ neighbor_points ,
644
+ neighbor_values ,
645
+ )
584
646
)
585
647
586
648
def _update_losses (self , to_delete : set , to_add : set ):
0 commit comments