@@ -75,6 +75,41 @@ def test_gromov(nx):
75
75
q , Gb .sum (0 ), atol = 1e-04 ) # cf convergence gromov
76
76
77
77
78
+ def test_gromov_dtype_device (nx ):
79
+ # setup
80
+ n_samples = 50 # nb samples
81
+
82
+ mu_s = np .array ([0 , 0 ])
83
+ cov_s = np .array ([[1 , 0 ], [0 , 1 ]])
84
+
85
+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu_s , cov_s , random_state = 4 )
86
+
87
+ xt = xs [::- 1 ].copy ()
88
+
89
+ p = ot .unif (n_samples )
90
+ q = ot .unif (n_samples )
91
+
92
+ C1 = ot .dist (xs , xs )
93
+ C2 = ot .dist (xt , xt )
94
+
95
+ C1 /= C1 .max ()
96
+ C2 /= C2 .max ()
97
+
98
+ for tp in nx .__type_list__ :
99
+ print (nx .dtype_device (tp ))
100
+
101
+ C1b = nx .from_numpy (C1 , type_as = tp )
102
+ C2b = nx .from_numpy (C2 , type_as = tp )
103
+ pb = nx .from_numpy (p , type_as = tp )
104
+ qb = nx .from_numpy (q , type_as = tp )
105
+
106
+ Gb = ot .gromov .gromov_wasserstein (C1b , C2b , pb , qb , 'square_loss' , verbose = True )
107
+ gw_valb = ot .gromov .gromov_wasserstein2 (C1b , C2b , pb , qb , 'kl_loss' , log = False )
108
+
109
+ nx .assert_same_dtype_device (C1b , Gb )
110
+ nx .assert_same_dtype_device (C1b , gw_valb )
111
+
112
+
78
113
def test_gromov2_gradients ():
79
114
n_samples = 50 # nb samples
80
115
@@ -168,6 +203,46 @@ def test_entropic_gromov(nx):
168
203
q , Gb .sum (0 ), atol = 1e-04 ) # cf convergence gromov
169
204
170
205
206
+ @pytest .skip_backend ("jax" , reason = "test very slow with jax backend" )
207
+ def test_entropic_gromov_dtype_device (nx ):
208
+ # setup
209
+ n_samples = 50 # nb samples
210
+
211
+ mu_s = np .array ([0 , 0 ])
212
+ cov_s = np .array ([[1 , 0 ], [0 , 1 ]])
213
+
214
+ xs = ot .datasets .make_2D_samples_gauss (n_samples , mu_s , cov_s , random_state = 42 )
215
+
216
+ xt = xs [::- 1 ].copy ()
217
+
218
+ p = ot .unif (n_samples )
219
+ q = ot .unif (n_samples )
220
+
221
+ C1 = ot .dist (xs , xs )
222
+ C2 = ot .dist (xt , xt )
223
+
224
+ C1 /= C1 .max ()
225
+ C2 /= C2 .max ()
226
+
227
+ for tp in nx .__type_list__ :
228
+ print (nx .dtype_device (tp ))
229
+
230
+ C1b = nx .from_numpy (C1 , type_as = tp )
231
+ C2b = nx .from_numpy (C2 , type_as = tp )
232
+ pb = nx .from_numpy (p , type_as = tp )
233
+ qb = nx .from_numpy (q , type_as = tp )
234
+
235
+ Gb = ot .gromov .entropic_gromov_wasserstein (
236
+ C1b , C2b , pb , qb , 'square_loss' , epsilon = 5e-4 , verbose = True
237
+ )
238
+ gw_valb = ot .gromov .entropic_gromov_wasserstein2 (
239
+ C1b , C2b , pb , qb , 'square_loss' , epsilon = 5e-4 , verbose = True
240
+ )
241
+
242
+ nx .assert_same_dtype_device (C1b , Gb )
243
+ nx .assert_same_dtype_device (C1b , gw_valb )
244
+
245
+
171
246
def test_pointwise_gromov (nx ):
172
247
n_samples = 50 # nb samples
173
248
0 commit comments