Skip to content

Commit f0beebf

Browse files
author
ncassereau
committed
gromov & entropic gromov
1 parent f31c15d commit f0beebf

File tree

1 file changed

+75
-0
lines changed

1 file changed

+75
-0
lines changed

test/test_gromov.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,41 @@ def test_gromov(nx):
7575
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
7676

7777

78+
def test_gromov_dtype_device(nx):
79+
# setup
80+
n_samples = 50 # nb samples
81+
82+
mu_s = np.array([0, 0])
83+
cov_s = np.array([[1, 0], [0, 1]])
84+
85+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4)
86+
87+
xt = xs[::-1].copy()
88+
89+
p = ot.unif(n_samples)
90+
q = ot.unif(n_samples)
91+
92+
C1 = ot.dist(xs, xs)
93+
C2 = ot.dist(xt, xt)
94+
95+
C1 /= C1.max()
96+
C2 /= C2.max()
97+
98+
for tp in nx.__type_list__:
99+
print(nx.dtype_device(tp))
100+
101+
C1b = nx.from_numpy(C1, type_as=tp)
102+
C2b = nx.from_numpy(C2, type_as=tp)
103+
pb = nx.from_numpy(p, type_as=tp)
104+
qb = nx.from_numpy(q, type_as=tp)
105+
106+
Gb = ot.gromov.gromov_wasserstein(C1b, C2b, pb, qb, 'square_loss', verbose=True)
107+
gw_valb = ot.gromov.gromov_wasserstein2(C1b, C2b, pb, qb, 'kl_loss', log=False)
108+
109+
nx.assert_same_dtype_device(C1b, Gb)
110+
nx.assert_same_dtype_device(C1b, gw_valb)
111+
112+
78113
def test_gromov2_gradients():
79114
n_samples = 50 # nb samples
80115

@@ -168,6 +203,46 @@ def test_entropic_gromov(nx):
168203
q, Gb.sum(0), atol=1e-04) # cf convergence gromov
169204

170205

206+
@pytest.skip_backend("jax", reason="test very slow with jax backend")
207+
def test_entropic_gromov_dtype_device(nx):
208+
# setup
209+
n_samples = 50 # nb samples
210+
211+
mu_s = np.array([0, 0])
212+
cov_s = np.array([[1, 0], [0, 1]])
213+
214+
xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
215+
216+
xt = xs[::-1].copy()
217+
218+
p = ot.unif(n_samples)
219+
q = ot.unif(n_samples)
220+
221+
C1 = ot.dist(xs, xs)
222+
C2 = ot.dist(xt, xt)
223+
224+
C1 /= C1.max()
225+
C2 /= C2.max()
226+
227+
for tp in nx.__type_list__:
228+
print(nx.dtype_device(tp))
229+
230+
C1b = nx.from_numpy(C1, type_as=tp)
231+
C2b = nx.from_numpy(C2, type_as=tp)
232+
pb = nx.from_numpy(p, type_as=tp)
233+
qb = nx.from_numpy(q, type_as=tp)
234+
235+
Gb = ot.gromov.entropic_gromov_wasserstein(
236+
C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
237+
)
238+
gw_valb = ot.gromov.entropic_gromov_wasserstein2(
239+
C1b, C2b, pb, qb, 'square_loss', epsilon=5e-4, verbose=True
240+
)
241+
242+
nx.assert_same_dtype_device(C1b, Gb)
243+
nx.assert_same_dtype_device(C1b, gw_valb)
244+
245+
171246
def test_pointwise_gromov(nx):
172247
n_samples = 50 # nb samples
173248

0 commit comments

Comments
 (0)