Skip to content

Commit 698c1aa

Browse files
Kilian FatrasKilian Fatras
authored andcommitted
replaced marginal tests
1 parent 104d210 commit 698c1aa

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

test/test_stochastic.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ def test_sag_asgd_sinkhorn():
9797

9898
x = rng.randn(n, 2)
9999
u = ot.utils.unif(n)
100-
zero = np.zeros(n)
101100
M = ot.dist(x, x)
102101

103102
G_asgd = ot.stochastic.solve_semi_dual_entropic(u, u, M, reg, "asgd",
@@ -108,13 +107,13 @@ def test_sag_asgd_sinkhorn():
108107

109108
# check constratints
110109
np.testing.assert_allclose(
111-
zero, (G_sag - G_sinkhorn).sum(1), atol=1e-03) # cf convergence sag
110+
G_sag.sum(1), G_sinkhorn.sum(1), atol=1e-03)
112111
np.testing.assert_allclose(
113-
zero, (G_sag - G_sinkhorn).sum(0), atol=1e-03) # cf convergence sag
112+
G_sag.sum(0), G_sinkhorn.sum(0), atol=1e-03)
114113
np.testing.assert_allclose(
115-
zero, (G_asgd - G_sinkhorn).sum(1), atol=1e-03) # cf convergence asgd
114+
G_asgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
116115
np.testing.assert_allclose(
117-
zero, (G_asgd - G_sinkhorn).sum(0), atol=1e-03) # cf convergence asgd
116+
G_asgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
118117
np.testing.assert_allclose(
119118
G_sag, G_sinkhorn, atol=1e-03) # cf convergence sag
120119
np.testing.assert_allclose(
@@ -175,7 +174,6 @@ def test_dual_sgd_sinkhorn():
175174
# Test uniform
176175
x = rng.randn(n, 2)
177176
u = ot.utils.unif(n)
178-
zero = np.zeros(n)
179177
M = ot.dist(x, x)
180178

181179
G_sgd = ot.stochastic.solve_dual_entropic(u, u, M, reg, batch_size,
@@ -185,17 +183,16 @@ def test_dual_sgd_sinkhorn():
185183

186184
# check constratints
187185
np.testing.assert_allclose(
188-
zero, abs(G_sgd - G_sinkhorn).sum(1), atol=1e-03) # cf convergence sgd
186+
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
189187
np.testing.assert_allclose(
190-
zero, abs(G_sgd - G_sinkhorn).sum(0), atol=1e-03) # cf convergence sgd
188+
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
191189
np.testing.assert_allclose(
192190
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd
193191

194192
# Test gaussian
195193
n = 30
196194
reg = 1
197195
batch_size = 30
198-
zero = np.zeros(n)
199196

200197
a = ot.datasets.make_1D_gauss(n, 15, 5) # m= mean, s= std
201198
b = ot.datasets.make_1D_gauss(n, 15, 5)
@@ -211,8 +208,8 @@ def test_dual_sgd_sinkhorn():
211208

212209
# check constratints
213210
np.testing.assert_allclose(
214-
zero, abs(G_sgd - G_sinkhorn).sum(1), atol=1e-03) # cf convergence sgd
211+
G_sgd.sum(1), G_sinkhorn.sum(1), atol=1e-03)
215212
np.testing.assert_allclose(
216-
zero, abs(G_sgd - G_sinkhorn).sum(0), atol=1e-03) # cf convergence sgd
213+
G_sgd.sum(0), G_sinkhorn.sum(0), atol=1e-03)
217214
np.testing.assert_allclose(
218215
G_sgd, G_sinkhorn, atol=1e-03) # cf convergence sgd

0 commit comments

Comments
 (0)