@@ -15,9 +15,9 @@ def seed():
15
15
@pytest .fixture (scope = "module" )
16
16
def normal_samples (seed ):
17
17
rng = np .random .default_rng (seed = seed )
18
- n_samples = 10 ** 4
18
+ n_samples = 10 ** 3
19
19
loc = 1
20
- scale = np . sqrt ( 2 )
20
+ scale = 1
21
21
return {
22
22
"samples" : rng .normal (loc = loc , scale = scale , size = (n_samples , 15 , 1 )),
23
23
"loc" : loc ,
@@ -28,7 +28,7 @@ def normal_samples(seed):
28
28
@pytest .fixture (scope = "module" )
29
29
def uniform_samples (seed ):
30
30
rng = np .random .default_rng (seed = seed )
31
- n_samples = 10 ** 4
31
+ n_samples = 10 ** 3
32
32
low = 0
33
33
high = 1
34
34
return {
@@ -115,7 +115,9 @@ def kurtosis(samples):
115
115
116
116
assert sample_kurtosis .shape == (batch , y_dim )
117
117
for i in range (batch ):
118
- assert np .allclose (sample_kurtosis [i , :], true_kurtosis , atol = 0.1 )
118
+ assert np .allclose (
119
+ sample_kurtosis [i , :], true_kurtosis , atol = 0.5
120
+ ), f"sample_kurtosis={ sample_kurtosis [i , :]} , true={ true_kurtosis } "
119
121
120
122
121
123
def test_samples_confidence_interval_and_quantiles (normal_samples ):
@@ -134,10 +136,10 @@ def test_samples_confidence_interval_and_quantiles(normal_samples):
134
136
assert sample_q_025 .shape == (1 , samples .batch , samples .y_dim )
135
137
assert sample_confidence_interval .shape == (2 , samples .batch , samples .y_dim )
136
138
for i in range (samples .batch ):
137
- assert np .allclose (sample_q_975 [:, i , :], true_q_975 , atol = 0.1 )
138
- assert np .allclose (sample_q_025 [:, i , :], true_q_025 , atol = 0.1 )
139
+ assert np .allclose (sample_q_975 [:, i , :], true_q_975 , atol = 0.3 )
140
+ assert np .allclose (sample_q_025 [:, i , :], true_q_025 , atol = 0.3 )
139
141
assert np .allclose (
140
- sample_confidence_interval [:, i , :].reshape (- 1 ), true_confidence_interval , atol = 0.1
142
+ sample_confidence_interval [:, i , :].reshape (- 1 ), true_confidence_interval , atol = 0.3
141
143
)
142
144
143
145
@@ -152,35 +154,37 @@ def test_samples_correlation(multivariate_normal_samples):
152
154
assert np .allclose (sample_correlation [i , :, :], true_correlation , atol = 0.1 )
153
155
154
156
155
- @pytest .mark .parametrize (
156
- "statistic, true_value" ,
157
- [
158
- ("sample_mean" , 1 ),
159
- ("sample_median" , 1 ),
160
- ("sample_mode" , 1 ),
161
- ("sample_std" , np .sqrt (2 )),
162
- ],
163
- )
164
- def test_samples_main_statistics (statistic , true_value , normal_samples ):
157
+ def test_samples_main_statistics (normal_samples ):
158
+ true_values = {
159
+ "sample_mean" : normal_samples ["loc" ],
160
+ "sample_median" : normal_samples ["loc" ],
161
+ "sample_mode" : normal_samples ["loc" ],
162
+ "sample_std" : normal_samples ["scale" ],
163
+ }
164
+
165
165
samples = Samples (normal_samples ["samples" ])
166
166
batch = normal_samples ["samples" ].shape [1 ]
167
167
168
- sample_stat = getattr (samples , statistic )()
169
- for i in range (batch ):
170
- assert np .allclose (sample_stat [i , ...], true_value , atol = 0.1 )
168
+ for statistic , true_value in true_values .items ():
169
+ sample_stat = getattr (samples , statistic )()
170
+ for i in range (batch ):
171
+ assert np .allclose (
172
+ sample_stat [i , ...], true_value , atol = 0.1
173
+ ), f"{ statistic } ={ sample_stat [i , ...]} vs. true={ true_value } "
174
+
171
175
176
+ def test_samples_max_min (uniform_samples ):
177
+ true_values = {
178
+ "sample_max" : uniform_samples ["high" ],
179
+ "sample_min" : uniform_samples ["low" ],
180
+ }
172
181
173
- @pytest .mark .parametrize (
174
- "statistic, true_value" ,
175
- [
176
- ("sample_max" , 1 ),
177
- ("sample_min" , 0 ),
178
- ],
179
- )
180
- def test_samples_max_min (statistic , true_value , uniform_samples ):
181
182
samples = Samples (uniform_samples ["samples" ])
182
183
batch = uniform_samples ["samples" ].shape [1 ]
183
184
184
- sample_stat = getattr (samples , statistic )()
185
- for i in range (batch ):
186
- assert np .allclose (sample_stat [i , ...], true_value , atol = 0.1 )
185
+ for statistic , true_value in true_values .items ():
186
+ sample_stat = getattr (samples , statistic )()
187
+ for i in range (batch ):
188
+ assert np .allclose (
189
+ sample_stat [i , ...], true_value , atol = 0.1
190
+ ), f"{ statistic } ={ sample_stat [i , ...]} vs. true={ true_value } "
0 commit comments