-
Notifications
You must be signed in to change notification settings - Fork 159
/
Copy pathsbiutils_test.py
307 lines (244 loc) · 10 KB
/
sbiutils_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
from typing import Tuple
import matplotlib.pyplot as plt
import pytest
import torch
from torch import Tensor, ones, zeros
from torch.distributions import MultivariateNormal
from sbi.inference import SNPE
from sbi.utils import (
BoxUniform,
conditional_corrcoeff,
conditional_pairplot,
eval_conditional_density,
posterior_nn,
)
def test_conditional_density_1d():
    """
    Compare sbi's 1D conditional density of a 3D Gaussian against the closed form.

    Dimension 0 is evaluated on a grid over [-3, 3] while the last two dimensions
    are conditioned on the value 1.0. Both the sbi result and the analytical
    density are normalized to sum to one over the grid before being compared.
    """
    num_grid_points = 100
    mean = torch.zeros(3)
    covariance = torch.tensor([[1.0, 0.0, 0.7], [0.0, 1.0, 0.7], [0.7, 0.7, 1.0]])
    joint = MultivariateNormal(mean, covariance)

    # Density as evaluated by sbi (dim1 == dim2 == 0 gives a 1D slice).
    sbi_unnormalized = eval_conditional_density(
        density=joint,
        condition=torch.ones(3),
        limits=torch.tensor([[-3, 3], [-3, 3], [-3, 3]]),
        dim1=0,
        dim2=0,
        resolution=num_grid_points,
    )
    sbi_normalized = sbi_unnormalized / sbi_unnormalized.sum()

    # Closed-form conditional p(x_0 | x_1 = 1, x_2 = 1) on the same grid.
    cond_mean, cond_cov = conditional_of_mvn(mean, covariance, torch.ones(2))
    grid = torch.linspace(-3, 3, num_grid_points).unsqueeze(1)
    exact_unnormalized = MultivariateNormal(cond_mean, cond_cov).log_prob(grid).exp()
    exact_normalized = exact_unnormalized / exact_unnormalized.sum()

    assert ((exact_normalized - sbi_normalized).abs() < 1e-5).all()
def test_conditional_density_2d():
    """
    Compare sbi's 2D conditional density of a 3D Gaussian against the closed form.

    Dimensions 0 and 1 are evaluated on a 2D grid over [-3, 3]^2 while the last
    dimension is conditioned on the value 1.0. Both the sbi result and the
    analytical density are normalized to sum to one over the grid before being
    compared.
    """
    num_grid_points = 100
    mean = torch.zeros(3)
    covariance = torch.tensor([[1.0, 0.0, 0.7], [0.0, 1.0, 0.7], [0.7, 0.7, 1.0]])
    joint = MultivariateNormal(mean, covariance)

    # Density as evaluated by sbi over dimensions 0 and 1.
    sbi_unnormalized = eval_conditional_density(
        density=joint,
        condition=torch.ones(3),
        limits=torch.tensor([[-3, 3], [-3, 3], [-3, 3]]),
        dim1=0,
        dim2=1,
        resolution=num_grid_points,
    )
    sbi_normalized = sbi_unnormalized / sbi_unnormalized.sum()

    # Flattened grid points: dimension 0 varies fastest and dimension 1 slowest, so
    # a row-major reshape puts dimension 1 along rows and dimension 0 along columns.
    axis_values = torch.linspace(-3, 3, num_grid_points)
    fast_coord = axis_values.repeat(num_grid_points)
    slow_coord = torch.repeat_interleave(axis_values, num_grid_points)
    grid = torch.stack((fast_coord, slow_coord), dim=1)

    # Closed-form conditional p(x_0, x_1 | x_2 = 1) on the same grid.
    cond_mean, cond_cov = conditional_of_mvn(mean, covariance, torch.ones(1))
    exact_flat = MultivariateNormal(cond_mean, cond_cov).log_prob(grid).exp()
    exact_unnormalized = exact_flat.reshape(num_grid_points, num_grid_points)
    exact_normalized = exact_unnormalized / exact_unnormalized.sum()

    assert ((exact_normalized - sbi_normalized).abs() < 1e-5).all()
def test_conditional_pairplot():
    """
    Smoke test: check that `conditional_pairplot()` runs without errors.

    It does not test the correctness of the plot. See `test_conditional_density_2d`
    for a test of `eval_conditional_density`, which is the core building block of
    `conditional_pairplot()`.
    """
    d = MultivariateNormal(
        torch.tensor([0.6, 5.0]), torch.tensor([[0.1, 0.99], [0.99, 10.0]])
    )
    try:
        _ = conditional_pairplot(
            density=d,
            condition=torch.ones(1, 2),
            limits=torch.tensor([[-1.0, 1.0], [-30, 30]]),
        )
    finally:
        # Close any pyplot figures opened by the call so that they do not
        # accumulate over the test session, even if plotting raised.
        plt.close("all")
def conditional_of_mvn(
    loc: Tensor, cov: Tensor, condition: Tensor
) -> Tuple[Tensor, Tensor]:
    """
    Return the mean and covariance of a conditional Gaussian.

    We always condition on the last variables: the joint is split into free
    variables x_1 (the leading entries) and observed variables x_2 (the trailing
    ``len(condition)`` entries). The standard Gaussian identities are used:

        mean = mu_1 + S_12 S_22^{-1} (x_2 - mu_2)
        cov  = S_11 - S_12 S_22^{-1} S_12^T

    Args:
        loc: Mean of the joint distribution, shape (D,).
        cov: Covariance matrix of the joint distribution, shape (D, D). Assumed
            symmetric; only the upper-right off-diagonal block S_12 is read.
        condition: Observed values of the last ``len(condition)`` variables. Must
            have fewer entries than `loc`.

    Returns:
        The conditional mean, shape (D - len(condition),), and the conditional
        covariance, shape (D - len(condition), D - len(condition)).

    Raises:
        ValueError: If `condition` does not have fewer entries than `loc`.
    """
    # Number of leading (unconditioned) dimensions that remain in the result.
    num_free_dims = loc.shape[0] - condition.shape[0]
    if num_free_dims <= 0:
        raise ValueError(
            f"`condition` must have fewer entries than `loc`, got "
            f"{condition.shape[0]} and {loc.shape[0]}."
        )

    mean_free = loc[:num_free_dims]
    mean_observed = loc[num_free_dims:]
    cov_free = cov[:num_free_dims, :num_free_dims]
    cov_cross = cov[:num_free_dims, num_free_dims:]
    cov_observed = cov[num_free_dims:, num_free_dims:]

    precision_observed = torch.inverse(cov_observed)
    # S_12 S_22^{-1}: shared by the mean shift and the covariance correction.
    gain = cov_cross @ precision_observed
    conditional_mean = mean_free + gain @ (condition - mean_observed)
    conditional_cov = cov_free - gain @ cov_cross.T
    return conditional_mean, conditional_cov
@pytest.mark.parametrize("corr", (0.99, 0.95, 0.0))
def test_conditional_corrcoeff(corr):
    """
    Check `conditional_corrcoeff` on a 2D Gaussian with a known correlation.

    The marginal variances are 0.1 and 10.0, whose product is 1, so the
    off-diagonal covariance entry `corr` equals the correlation coefficient.
    """
    covariance = torch.tensor([[0.1, corr], [corr, 10.0]])
    gaussian = MultivariateNormal(torch.tensor([0.6, 5.0]), covariance)
    corr_matrix = conditional_corrcoeff(
        density=gaussian,
        condition=torch.ones(1, 2),
        limits=torch.tensor([[-2.0, 3.0], [-70, 90]]),
        resolution=500,
    )
    estimated_corr = corr_matrix[0, 1]
    assert (corr - estimated_corr).abs() < 1e-3
def test_average_cond_coeff_matrix():
    """
    Check the full matrix returned by `conditional_corrcoeff` on a 3D Gaussian.

    Dimensions 0 and 1 have covariance 30 with variances 100 and 10, giving a
    correlation of sqrt(30**2 / (100 * 10)). Dimension 2 has zero covariance with
    both, so the expected matrix is block-diagonal with ones on the diagonal.
    """
    covariance = torch.tensor([[100.0, 30.0, 0], [30.0, 10.0, 0], [0, 0, 1.0]])
    gaussian = MultivariateNormal(torch.tensor([10.0, 5, 1]), covariance)
    estimated = conditional_corrcoeff(
        density=gaussian,
        condition=torch.zeros(1, 3),
        limits=torch.tensor([[-60.0, 60.0], [-20, 20], [-7, 7]]),
        resolution=500,
    )

    rho = torch.sqrt(torch.tensor(30.0 ** 2 / 100.0 / 10.0))
    expected = torch.tensor([[1.0, rho, 0.0], [rho, 1.0, 0.0], [0.0, 0.0, 1.0]])
    assert ((expected - estimated).abs() < 1e-3).all()
def test_apt_transform(plot_results: bool = False):
    """
    Test whether the product between proposal and posterior is computed correctly.

    This initializes two mixtures of Gaussians (MoGs) with two components each. It
    then evaluates their product on a grid by simply multiplying the densities of
    the two. The result is compared to the product of the two MoGs as computed by
    APT's `_automatic_posterior_transformation()`. Both products are rescaled to a
    maximum of 1 before comparing, so only their shapes have to agree.

    Args:
        plot_results: Whether to plot the two densities and both products.
    """

    class MoG:
        """Minimal mixture of Gaussians that evaluates its (non-log) density."""

        def __init__(self, means, covs, weights):
            self._means = means
            self._covs = covs
            self._weights = weights

        def prob(self, theta):
            """Return the weighted sum of component densities at each row of theta."""
            probs = zeros(theta.shape[0])
            for mean, cov, weight in zip(self._means, self._covs, self._weights):
                component = MultivariateNormal(mean, cov)
                probs += torch.exp(component.log_prob(theta)) * weight
            return probs

    # Build a grid on which to evaluate the densities.
    bound = 5.0
    resolution = 100
    theta_range = torch.linspace(-bound, bound, resolution)
    theta1_grid, theta2_grid = torch.meshgrid(theta_range, theta_range)
    theta_grid_flat = torch.reshape(
        torch.stack([theta1_grid, theta2_grid]), (2, resolution ** 2)
    )
    # One (theta1, theta2) grid point per row.
    theta_points = theta_grid_flat.T

    # Generate two MoGs.
    means1 = torch.tensor([[2.0, 2.0], [-2.0, -2.0]])
    covs1 = torch.stack([0.5 * torch.eye(2), torch.eye(2)])
    weights1 = torch.tensor([0.3, 0.7])
    means2 = torch.tensor([[2.0, -2.2], [-2.0, 1.9]])
    covs2 = torch.stack([0.6 * torch.eye(2), 0.9 * torch.eye(2)])
    weights2 = torch.tensor([0.6, 0.4])
    mog1 = MoG(means1, covs1, weights1)
    mog2 = MoG(means2, covs2, weights2)

    # Evaluate the product of their pdfs by evaluating them separately and multiplying.
    probs1 = torch.reshape(mog1.prob(theta_points), (resolution, resolution))
    probs2 = torch.reshape(mog2.prob(theta_points), (resolution, resolution))
    probs_mult = probs1 * probs2

    # Set up a SNPE object in order to use `_automatic_posterior_transformation()`.
    # Training for one epoch on random data is only setup; the MoG parameters
    # under test are passed to the transformation explicitly below.
    prior = BoxUniform(-5 * ones(2), 5 * ones(2))
    density_estimator = posterior_nn("mdn", z_score_theta=False, z_score_x=False)
    inference = SNPE(prior=prior, density_estimator=density_estimator)
    num_simulations = 100
    theta_ = torch.rand(num_simulations, 2)
    x_ = torch.rand(num_simulations, 2)
    _ = inference.append_simulations(theta_, x_).train(max_num_epochs=1)
    inference._set_state_for_mog_proposal()

    precs1 = torch.inverse(covs1)
    precs2 = torch.inverse(covs2)

    # `.unsqueeze(0)` is needed because the method requires a batch dimension.
    logits_pp, means_pp, _, covs_pp = inference._automatic_posterior_transformation(
        torch.log(weights1.unsqueeze(0)),
        means1.unsqueeze(0),
        precs1.unsqueeze(0),
        torch.log(weights2.unsqueeze(0)),
        means2.unsqueeze(0),
        precs2.unsqueeze(0),
    )

    # Normalize the mixture weights of the product.
    logits_pp_norm = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)
    weights_pp = torch.exp(logits_pp_norm)

    # Evaluate the product as computed by APT (first and only batch entry).
    mog_apt = MoG(means_pp[0], covs_pp[0], weights_pp[0])
    probs_apt = torch.reshape(mog_apt.prob(theta_points), (resolution, resolution))

    # Compare both products after rescaling each to a maximum of 1.
    norm_probs_mult = probs_mult / torch.max(probs_mult)
    norm_probs_apt = probs_apt / torch.max(probs_apt)
    error = torch.abs(norm_probs_mult - norm_probs_apt)
    assert torch.max(error) < 1e-5

    if plot_results:
        _, ax = plt.subplots(1, 4, figsize=(16, 4))
        extent = [-bound, bound, -bound, bound]
        panels = (
            (probs1, "p_1"),
            (probs2, "p_2"),
            (probs_mult, "p_1 * p_2"),
            (probs_apt, "APT"),
        )
        for axis, (image, title) in zip(ax, panels):
            axis.imshow(image, extent=extent)
            axis.set_title(title)
        plt.show()