-
Notifications
You must be signed in to change notification settings - Fork 40
/
vegas_map.py
277 lines (245 loc) · 10.9 KB
/
vegas_map.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
from autoray import numpy as anp
from autoray import astype, to_backend_dtype
from loguru import logger
from .utils import _add_at_indices
class VEGASMap:
    """The map used for VEGAS Enhanced. Refer to https://arxiv.org/abs/2009.05112 .
    Implementation is inspired by https://github.com/ycwu1030/CIGAR/ .
    EQ <n> refers to equation <n> in the above paper.

    The map is a per-dimension piecewise-linear transformation of [0, 1]^dim.
    Each dimension is split into ``N_intervals`` intervals, described by:

    * ``x_edges``: interval boundaries, shape ``(dim, N_intervals + 1)``.
    * ``dx_edges``: interval widths, shape ``(dim, N_intervals)``.

    Weights collected with ``accumulate_weight`` are turned into new edges by
    ``update_map``.
    """

    def __init__(self, N_intervals, dim, backend, dtype, alpha=0.5) -> None:
        """Initializes VEGAS Enhanced's adaptive map with equally spaced edges.

        Args:
            N_intervals (int): Number of intervals per dimension to split the domain in.
            dim (int): Dimensionality of the integration domain.
            backend (string): Numerical backend
            dtype (backend dtype): dtype used for the calculations
            alpha (float, optional): Alpha from the paper, EQ 19. Defaults to 0.5.
        """
        self.dim = dim
        self.N_intervals = N_intervals  # number of subdivisions per dimension
        N_edges = self.N_intervals + 1  # number of subdivision boundaries
        self.alpha = alpha  # Weight smoothing (range-compression exponent, EQ 19)
        self.backend = backend
        self.dtype = dtype

        # Boundary locations x_edges and subdomain stepsizes dx_edges
        # Subdivide the domain [0,1]^dim equally spaced in N-d, EQ 8
        self.dx_edges = (
            anp.ones((self.dim, self.N_intervals), dtype=self.dtype, like=self.backend)
            / self.N_intervals
        )
        x_edges_per_dim = anp.linspace(
            0.0, 1.0, N_edges, dtype=self.dtype, like=self.backend
        )
        # Same 1D edge vector for every dimension -> shape (dim, N_edges)
        self.x_edges = anp.repeat(
            anp.reshape(x_edges_per_dim, [1, N_edges]), self.dim, axis=0
        )

        # Initialize self.weights and self.counts
        self._reset_weight()

    def get_X(self, y):
        """Get mapped sampling points, EQ 9.

        For each dimension ``i``, the mapped coordinate is
        ``x_edges[i, ID] + dx_edges[i, ID] * offset``. Here ``ID`` and
        ``offset`` are the integer and fractional parts of
        ``y[:, i] * N_intervals``.

        Args:
            y (backend tensor): Randomly sampled location(s), indexed as
                ``y[:, i]`` for dimension ``i``. Presumably in [0, 1) so that
                the interval IDs stay in ``[0, N_intervals)``; verify at
                the caller.

        Returns:
            backend tensor: Mapped points, same shape and dtype as ``y``.
        """
        ID, offset = self._get_interval_ID(y), self._get_interval_offset(y)
        res = anp.zeros_like(y)
        for i in range(self.dim):
            ID_i = ID[:, i]
            res[:, i] = self.x_edges[i, ID_i] + self.dx_edges[i, ID_i] * offset[:, i]
        return res

    def get_Jac(self, y):
        """Computes the jacobian of the mapping transformation, EQ 12.

        The jacobian is the product over all dimensions of
        ``N_intervals * dx_edges[i, ID_i]``, where ``ID_i`` is the interval
        containing each sample in dimension ``i``.

        Args:
            y (backend tensor): Sampled locations, indexed as ``y[:, i]``.

        Returns:
            backend tensor: Jacobian, one value per sample
            (shape ``[y.shape[0]]``).
        """
        ID = self._get_interval_ID(y)
        jac = anp.ones([y.shape[0]], dtype=y.dtype, like=y)
        for i in range(self.dim):
            ID_i = ID[:, i]
            jac *= self.N_intervals * self.dx_edges[i][ID_i]
        return jac

    def _get_interval_ID(self, y):
        """Get the integer part of the desired mapping, EQ 10.

        Args:
            y (backend tensor): Sampled points

        Returns:
            backend tensor: ``floor(y * N_intervals)`` cast to int64.
        """
        return astype(anp.floor(y * float(self.N_intervals)), "int64")

    def _get_interval_offset(self, y):
        """Get the fractional part of the desired mapping, EQ 11.

        Args:
            y (backend tensor): Sampled points.

        Returns:
            backend tensor: ``y * N_intervals - floor(y * N_intervals)``.
        """
        y = y * float(self.N_intervals)
        return y - anp.floor(y)

    def accumulate_weight(self, y, jf_vec2):
        """Accumulate weights and counts of the map.

        For every dimension ``i``, ``jf_vec2`` is added to ``self.weights[i]``
        at the interval containing each sample. One is also added to
        ``self.counts[i]`` at the same interval.

        Args:
            y (backend tensor): Sampled points, indexed as ``y[:, i]``.
            jf_vec2 (backend tensor): Square of the product of function values and jacobians
                (presumably one value per sample, i.e. ``y.shape[0]`` entries;
                confirm at the caller)
        """
        ones = anp.ones(jf_vec2.shape, dtype=self.counts.dtype, like=jf_vec2)
        ID = self._get_interval_ID(y)
        for i in range(self.dim):
            ID_i = ID[:, i]
            _add_at_indices(self.weights[i], ID_i, jf_vec2)
            _add_at_indices(self.counts[i], ID_i, ones)

    @staticmethod
    def _smooth_map(weights, counts, alpha):
        """Smooth the weights in the map, EQ 18 - 22.

        Note: ``weights`` is modified in place; it is divided by ``counts``
        and zero-count intervals are filled.

        Args:
            weights (backend tensor): Accumulated J^2 f^2 per interval,
                shape ``(dim, N_intervals)``.
            counts (backend tensor): Number of samples per interval, same
                shape as ``weights``.
            alpha (float): Range-compression exponent, EQ 19.

        Returns:
            backend tensor or None: Smoothed, range-compressed weights of
            shape ``(dim, N_intervals)``. Returns ``None`` if the weights of
            any dimension sum to zero, because the map cannot be updated then.
        """
        # Get the average values for J^2 f^2 (weights)
        # EQ 17
        z_idx = counts == 0  # zero count indices
        if anp.any(z_idx):
            nnz_idx = anp.logical_not(z_idx)
            weights[nnz_idx] /= counts[nnz_idx]
            logger.opt(lazy=True).debug(
                "The integrand was not evaluated in {z_idx_sum} of {num_weights} VEGASMap intervals. "
                "Filling the weights for some of them with neighbouring values.",
                z_idx_sum=lambda: anp.sum(z_idx),
                num_weights=lambda: counts.shape[0] * counts.shape[1],
            )
            # Set the weights of the intervals with zero count to weights from
            # their nearest neighbouring intervals
            # (up to a distance of 10 indices).
            for _ in range(10):
                # Pull values in from the right neighbour...
                weights[:, :-1] = anp.where(
                    z_idx[:, :-1], weights[:, 1:], weights[:, :-1]
                )
                # The asterisk corresponds to a logical And here
                z_idx[:, :-1] = z_idx[:, :-1] * z_idx[:, 1:]
                # ...then from the left neighbour
                weights[:, 1:] = anp.where(
                    z_idx[:, 1:], weights[:, :-1], weights[:, 1:]
                )
                z_idx[:, 1:] = z_idx[:, 1:] * z_idx[:, :-1]
                logger.opt(lazy=True).debug(
                    " remaining intervals: {z_idx_sum}",
                    z_idx_sum=lambda: anp.sum(z_idx),
                )
                if not anp.any(z_idx):
                    break
        else:
            weights /= counts

        # Convolve with [1/8, 6/8, 1/8] in each dimension to smooth the
        # weights; boundary behaviour: repeat border values.
        # Divide by d_sum to normalize (divide by the sum before smoothing)
        # EQ 18
        dim, N_intervals = weights.shape
        weights_sums = anp.reshape(anp.sum(weights, axis=1), [dim, 1])
        if anp.any(weights_sums == 0.0):
            # The VEGASMap cannot be updated in dimensions where all weights
            # are zero.
            return None
        i_tmp = N_intervals - 2
        d_tmp = anp.concatenate(
            [
                # First interval: its own value repeated as the left neighbour
                7.0 * weights[:, 0:1] + weights[:, 1:2],
                weights[:, :-2] + 6.0 * weights[:, 1:-1] + weights[:, 2:],
                # Last interval: its own value repeated as the right neighbour
                weights[:, i_tmp : i_tmp + 1] + 7.0 * weights[:, i_tmp + 1 : i_tmp + 2],
            ],
            axis=1,
            like=weights,
        )
        d_tmp = d_tmp / (8.0 * weights_sums)

        # Range compression
        # EQ 19
        # Zero entries are skipped to avoid evaluating log(0).
        d_tmp[d_tmp != 0] = (
            (d_tmp[d_tmp != 0] - 1.0) / anp.log(d_tmp[d_tmp != 0])
        ) ** alpha

        return d_tmp

    def _reset_weight(self):
        """Reset or initialize weights and counts to zero arrays of shape (dim, N_intervals)."""
        # weights in each interval
        self.weights = anp.zeros(
            (self.dim, self.N_intervals), dtype=self.dtype, like=self.backend
        )
        # numbers of random samples in specific interval
        self.counts = anp.zeros(
            (self.dim, self.N_intervals),
            dtype=to_backend_dtype("int64", like=self.backend),
            like=self.backend,
        )

    def update_map(self):
        """Update the adaptive map, Section II C.

        Recomputes ``x_edges`` and ``dx_edges`` from the accumulated weights
        so that each new interval holds an equal share of the smoothed
        weights (EQ 20 - 22). Afterwards, weights and counts are reset.

        If the smoothed weights cannot be computed (some dimension has only
        zero weights), a warning is logged, weights and counts are reset, and
        the edges are left unchanged.

        Raises:
            RuntimeError: If non-finite edges remain after trying to replace
                them with the average of their neighbours.
        """
        smoothed_weights = self._smooth_map(self.weights, self.counts, self.alpha)
        if smoothed_weights is None:
            logger.warning(
                "Cannot update the VEGASMap. This can happen with an integrand "
                "which evaluates to zero everywhere."
            )
            self._reset_weight()
            return

        # The amount of the sum of smoothed_weights for each interval of
        # the new 1D grid, for each dimension
        # EQ 20
        delta_weights = anp.sum(smoothed_weights, axis=1) / self.N_intervals

        for i in range(self.dim):  # Update per dim
            delta_d = delta_weights[i]
            # For each inner edge, determine how many delta_d fit into the
            # accumulated smoothed weights.
            # With torch, CUDA and a high number of points the cumsum operation
            # with float32 precision is too inaccurate which leads to wrong
            # indices, so cast to float64 here.
            delta_d_multiples = astype(
                anp.cumsum(astype(smoothed_weights[i, :-1], "float64"), axis=0)
                / delta_d,
                "int64",
            )
            # For each number of delta_d multiples in {0, 1, …, N_intervals},
            # determine how many intervals belong to it (num_sw_per_dw)
            # and the sum of smoothed weights in these intervals (val_sw_per_dw)
            dtype_int = delta_d_multiples.dtype
            num_sw_per_dw = anp.zeros(
                [self.N_intervals + 1], dtype=dtype_int, like=delta_d
            )
            _add_at_indices(
                num_sw_per_dw,
                delta_d_multiples,
                anp.ones(delta_d_multiples.shape, dtype=dtype_int, like=delta_d),
                is_sorted=True,
            )
            val_sw_per_dw = anp.zeros(
                [self.N_intervals + 1], dtype=self.dtype, like=delta_d
            )
            # delta_d_multiples covers only the first N_intervals - 1 smoothed
            # weights (the cumsum above excludes the last interval). The source
            # is sliced to that same length, so every value has exactly one
            # target index, rather than passing a longer array whose final
            # element has no index.
            _add_at_indices(
                val_sw_per_dw,
                delta_d_multiples,
                smoothed_weights[i, :-1],
                is_sorted=True,
            )

            # The cumulative sum of the number of smoothed weights per delta_d
            # multiple determines the old inner edges indices for the new inner
            # edges calculation
            indices = anp.cumsum(num_sw_per_dw[:-2], axis=0)
            # d_accu_i is used for the interpolation in the new inner edges
            # calculation when adding it to the old inner edges
            d_accu_i = anp.cumsum(delta_d - val_sw_per_dw[:-2], axis=0)

            # EQ 22
            self.x_edges[i][1:-1] = (
                self.x_edges[i][indices]
                + d_accu_i / smoothed_weights[i][indices] * self.dx_edges[i][indices]
            )

            finite_edges = anp.isfinite(self.x_edges[i])
            if not anp.all(finite_edges):
                # With float64 precision the delta_d_multiples calculation
                # usually doesn't have rounding errors.
                # If it is nonetheless too inaccurate, few values in
                # smoothed_weights[i][indices] can be zero, which leads to
                # invalid edges.
                num_edges = self.x_edges.shape[1]
                logger.warning(
                    f"{num_edges - anp.sum(finite_edges)} out of {num_edges} calculated VEGASMap edges were infinite"
                )
                # Replace inf edges with the average of their two neighbours
                middle_edges = 0.5 * (self.x_edges[i][:-2] + self.x_edges[i][2:])
                self.x_edges[i][1:-1] = anp.where(
                    finite_edges[1:-1], self.x_edges[i][1:-1], middle_edges
                )
                if not anp.all(anp.isfinite(self.x_edges[i])):
                    raise RuntimeError("Could not replace all infinite edges")

            self.dx_edges[i] = self.x_edges[i][1:] - self.x_edges[i][:-1]

        self._reset_weight()