Skip to content

Commit 368ea2b

Browse files
committed
reduce sample size to speed up tests
1 parent 3cb0041 commit 368ea2b

File tree

1 file changed

+35
-31
lines changed

1 file changed

+35
-31
lines changed

tests/test_samples.py

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ def seed():
1515
@pytest.fixture(scope="module")
1616
def normal_samples(seed):
1717
rng = np.random.default_rng(seed=seed)
18-
n_samples = 10**4
18+
n_samples = 10**3
1919
loc = 1
20-
scale = np.sqrt(2)
20+
scale = 1
2121
return {
2222
"samples": rng.normal(loc=loc, scale=scale, size=(n_samples, 15, 1)),
2323
"loc": loc,
@@ -28,7 +28,7 @@ def normal_samples(seed):
2828
@pytest.fixture(scope="module")
2929
def uniform_samples(seed):
3030
rng = np.random.default_rng(seed=seed)
31-
n_samples = 10**4
31+
n_samples = 10**3
3232
low = 0
3333
high = 1
3434
return {
@@ -115,7 +115,9 @@ def kurtosis(samples):
115115

116116
assert sample_kurtosis.shape == (batch, y_dim)
117117
for i in range(batch):
118-
assert np.allclose(sample_kurtosis[i, :], true_kurtosis, atol=0.1)
118+
assert np.allclose(
119+
sample_kurtosis[i, :], true_kurtosis, atol=0.5
120+
), f"sample_kurtosis={sample_kurtosis[i, :]}, true={true_kurtosis}"
119121

120122

121123
def test_samples_confidence_interval_and_quantiles(normal_samples):
@@ -134,10 +136,10 @@ def test_samples_confidence_interval_and_quantiles(normal_samples):
134136
assert sample_q_025.shape == (1, samples.batch, samples.y_dim)
135137
assert sample_confidence_interval.shape == (2, samples.batch, samples.y_dim)
136138
for i in range(samples.batch):
137-
assert np.allclose(sample_q_975[:, i, :], true_q_975, atol=0.1)
138-
assert np.allclose(sample_q_025[:, i, :], true_q_025, atol=0.1)
139+
assert np.allclose(sample_q_975[:, i, :], true_q_975, atol=0.3)
140+
assert np.allclose(sample_q_025[:, i, :], true_q_025, atol=0.3)
139141
assert np.allclose(
140-
sample_confidence_interval[:, i, :].reshape(-1), true_confidence_interval, atol=0.1
142+
sample_confidence_interval[:, i, :].reshape(-1), true_confidence_interval, atol=0.3
141143
)
142144

143145

@@ -152,35 +154,37 @@ def test_samples_correlation(multivariate_normal_samples):
152154
assert np.allclose(sample_correlation[i, :, :], true_correlation, atol=0.1)
153155

154156

155-
@pytest.mark.parametrize(
156-
"statistic, true_value",
157-
[
158-
("sample_mean", 1),
159-
("sample_median", 1),
160-
("sample_mode", 1),
161-
("sample_std", np.sqrt(2)),
162-
],
163-
)
164-
def test_samples_main_statistics(statistic, true_value, normal_samples):
157+
def test_samples_main_statistics(normal_samples):
158+
true_values = {
159+
"sample_mean": normal_samples["loc"],
160+
"sample_median": normal_samples["loc"],
161+
"sample_mode": normal_samples["loc"],
162+
"sample_std": normal_samples["scale"],
163+
}
164+
165165
samples = Samples(normal_samples["samples"])
166166
batch = normal_samples["samples"].shape[1]
167167

168-
sample_stat = getattr(samples, statistic)()
169-
for i in range(batch):
170-
assert np.allclose(sample_stat[i, ...], true_value, atol=0.1)
168+
for statistic, true_value in true_values.items():
169+
sample_stat = getattr(samples, statistic)()
170+
for i in range(batch):
171+
assert np.allclose(
172+
sample_stat[i, ...], true_value, atol=0.1
173+
), f"{statistic}={sample_stat[i, ...]} vs. true={true_value}"
174+
171175

176+
def test_samples_max_min(uniform_samples):
177+
true_values = {
178+
"sample_max": uniform_samples["high"],
179+
"sample_min": uniform_samples["low"],
180+
}
172181

173-
@pytest.mark.parametrize(
174-
"statistic, true_value",
175-
[
176-
("sample_max", 1),
177-
("sample_min", 0),
178-
],
179-
)
180-
def test_samples_max_min(statistic, true_value, uniform_samples):
181182
samples = Samples(uniform_samples["samples"])
182183
batch = uniform_samples["samples"].shape[1]
183184

184-
sample_stat = getattr(samples, statistic)()
185-
for i in range(batch):
186-
assert np.allclose(sample_stat[i, ...], true_value, atol=0.1)
185+
for statistic, true_value in true_values.items():
186+
sample_stat = getattr(samples, statistic)()
187+
for i in range(batch):
188+
assert np.allclose(
189+
sample_stat[i, ...], true_value, atol=0.1
190+
), f"{statistic}={sample_stat[i, ...]} vs. true={true_value}"

0 commit comments

Comments
 (0)