@@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
73
73
>>> a=[.5, .5]
74
74
>>> b=[.5, .5]
75
75
>>> M=[[0., 1.], [1., 0.]]
76
- >>> ot.sinkhorn2(a, b, M, 1, 1)
77
- array([0.26894142])
76
+ >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
77
+ array([[0.51122823, 0.18807035],
78
+ [0.18807035, 0.51122823]])
78
79
79
80
80
81
References
@@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
91
92
92
93
See Also
93
94
--------
94
- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
95
- ot.unbalanced.sinkhorn_stabilized : Unbalanced Stabilized sinkhorn [9][10]
96
- ot.unbalanced.sinkhorn_epsilon_scaling : Unbalanced Sinkhorn with epslilon scaling [9][10]
95
+ ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
96
+ ot.unbalanced.sinkhorn_stabilized_unbalanced : Unbalanced Stabilized sinkhorn [9][10]
97
+ ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced : Unbalanced Sinkhorn with epslilon scaling [9][10]
97
98
98
99
"""
99
100
100
101
if method .lower () == 'sinkhorn' :
101
102
def sink ():
102
- return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
103
- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
104
- else :
105
- warnings .warn ('Unknown method. Falling back to classic Sinkhorn Knopp' )
103
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
104
+ numItermax = numItermax ,
105
+ stopThr = stopThr , verbose = verbose ,
106
+ log = log , ** kwargs )
107
+
108
+ elif method .lower () in ['sinkhorn_stabilized' , 'sinkhorn_epsilon_scaling' ]:
109
+ warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
106
110
107
111
def sink ():
108
- return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
109
- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
112
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
113
+ numItermax = numItermax ,
114
+ stopThr = stopThr , verbose = verbose ,
115
+ log = log , ** kwargs )
116
+ else :
117
+ raise ValueError ('Unknown method. Using classic Sinkhorn Knopp' )
110
118
111
119
return sink ()
112
120
113
121
114
- def sinkhorn2 (a , b , M , reg , alpha , method = 'sinkhorn' , numItermax = 1000 ,
115
- stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
122
+ def sinkhorn_unbalanced2 (a , b , M , reg , alpha , method = 'sinkhorn' ,
123
+ numItermax = 1000 , stopThr = 1e-9 , verbose = False ,
124
+ log = False , ** kwargs ):
116
125
u"""
117
126
Solve the entropic regularization unbalanced optimal transport problem and return the loss
118
127
@@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
173
182
>>> a=[.5, .10]
174
183
>>> b=[.5, .5]
175
184
>>> M=[[0., 1.],[1., 0.]]
176
- >>> ot.sinkhorn2 (a, b, M, 1., 1.)
177
- array([ 0.26894142 ])
185
+ >>> ot.unbalanced.sinkhorn_unbalanced2 (a, b, M, 1., 1.)
186
+ array([0.31912866 ])
178
187
179
188
180
189
@@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
199
208
200
209
if method .lower () == 'sinkhorn' :
201
210
def sink ():
202
- return sinkhorn_knopp (a , b , M , reg , alpha , numItermax = numItermax ,
203
- stopThr = stopThr , verbose = verbose , log = log , ** kwargs )
204
- else :
205
- warnings .warn ('Unknown method using classic Sinkhorn Knopp' )
211
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
212
+ numItermax = numItermax ,
213
+ stopThr = stopThr , verbose = verbose ,
214
+ log = log , ** kwargs )
215
+
216
+ elif method .lower () in ['sinkhorn_stabilized' , 'sinkhorn_epsilon_scaling' ]:
217
+ warnings .warn ('Method not implemented yet. Using classic Sinkhorn Knopp' )
206
218
207
219
def sink ():
208
- return sinkhorn_knopp (a , b , M , reg , alpha , ** kwargs )
220
+ return sinkhorn_knopp_unbalanced (a , b , M , reg , alpha ,
221
+ numItermax = numItermax ,
222
+ stopThr = stopThr , verbose = verbose ,
223
+ log = log , ** kwargs )
224
+ else :
225
+ raise ValueError ('Unknown method. Using classic Sinkhorn Knopp' )
209
226
210
227
b = np .asarray (b , dtype = np .float64 )
211
228
if len (b .shape ) < 2 :
212
- b = b [None , : ]
229
+ b = b [:, None ]
213
230
214
231
return sink ()
215
232
216
233
217
- def sinkhorn_knopp (a , b , M , reg , alpha , numItermax = 1000 ,
218
- stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
234
+ def sinkhorn_knopp_unbalanced (a , b , M , reg , alpha , numItermax = 1000 ,
235
+ stopThr = 1e-9 , verbose = False , log = False , ** kwargs ):
219
236
"""
220
237
Solve the entropic regularization unbalanced optimal transport problem and return the loss
221
238
@@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
273
290
>>> a=[.5, .15]
274
291
>>> b=[.5, .5]
275
292
>>> M=[[0., 1.],[1., 0.]]
276
- >>> ot.sinkhorn(a, b, M, 1., 1.)
277
- array([[ 0.36552929, 0.13447071],
278
- [ 0.13447071, 0.36552929]])
279
-
293
+ >>> ot.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
294
+ array([[0.52761554, 0.22392482],
295
+ [0.10286295, 0.32257641]])
280
296
281
297
References
282
298
----------
@@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
303
319
if len (b ) == 0 :
304
320
b = np .ones (n_b , dtype = np .float64 ) / n_b
305
321
306
- assert n_a == len (a ) and n_b == len (b )
307
- if b .ndim > 1 :
322
+ if len (b .shape ) > 1 :
308
323
n_hists = b .shape [1 ]
309
324
else :
310
325
n_hists = 0
@@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
315
330
# we assume that no distances are null except those of the diagonal of
316
331
# distances
317
332
if n_hists :
318
- u = np .ones ((n_a , n_hists )) / n_a
333
+ u = np .ones ((n_a , 1 )) / n_a
319
334
v = np .ones ((n_b , n_hists )) / n_b
335
+ a = a .reshape (n_a , 1 )
320
336
else :
321
337
u = np .ones (n_a ) / n_a
322
338
v = np .ones (n_b ) / n_b
@@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
332
348
333
349
cpt = 0
334
350
err = 1.
351
+
335
352
while (err > stopThr and cpt < numItermax ):
336
353
uprev = u
337
354
vprev = v
@@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
473
490
or np .any (np .isinf (u )) or np .any (np .isinf (v ))):
474
491
# we have reached the machine precision
475
492
# come back to previous solution and quit loop
476
- warnings .warn ('Numerical errors at iteration' , cpt )
493
+ warnings .warn ('Numerical errors at iteration %s' % cpt )
477
494
u = uprev
478
495
v = vprev
479
496
break
0 commit comments