Skip to content

Commit 8979827

Browse files
author
Hicham Janati
committed
fix func names + add more tests
1 parent 50bc900 commit 8979827

File tree

4 files changed

+127
-35
lines changed

4 files changed

+127
-35
lines changed

ot/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# OT functions
2626
from .lp import emd, emd2
2727
from .bregman import sinkhorn, sinkhorn2, barycenter
28-
from .unbalanced import sinkhorn_unbalanced
28+
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced
2929
from .da import sinkhorn_lpl1_mm
3030

3131
# utils functions

ot/bregman.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def sink():
241241

242242
b = np.asarray(b, dtype=np.float64)
243243
if len(b.shape) < 2:
244-
b = b.reshape((-1, 1))
244+
b = b[:, None]
245245

246246
return sink()
247247

ot/unbalanced.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,9 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
7373
>>> a=[.5, .5]
7474
>>> b=[.5, .5]
7575
>>> M=[[0., 1.], [1., 0.]]
76-
>>> ot.sinkhorn2(a, b, M, 1, 1)
77-
array([0.26894142])
76+
>>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
77+
array([[0.51122823, 0.18807035],
78+
[0.18807035, 0.51122823]])
7879
7980
8081
References
@@ -91,28 +92,36 @@ def sinkhorn_unbalanced(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
9192
9293
See Also
9394
--------
94-
ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn [10]
95-
ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn [9][10]
96-
ot.unbalanced.sinkhorn_epsilon_scaling: Unbalanced Sinkhorn with epslilon scaling [9][10]
95+
ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn [10]
96+
ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn [9][10]
97+
ot.unbalanced.sinkhorn_epsilon_scaling_unbalanced: Unbalanced Sinkhorn with epsilon scaling [9][10]
9798
9899
"""
99100

100101
if method.lower() == 'sinkhorn':
101102
def sink():
102-
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
103-
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
104-
else:
105-
warnings.warn('Unknown method. Falling back to classic Sinkhorn Knopp')
103+
return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
104+
numItermax=numItermax,
105+
stopThr=stopThr, verbose=verbose,
106+
log=log, **kwargs)
107+
108+
elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
109+
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
106110

107111
def sink():
108-
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
109-
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
112+
return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
113+
numItermax=numItermax,
114+
stopThr=stopThr, verbose=verbose,
115+
log=log, **kwargs)
116+
else:
117+
raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
110118

111119
return sink()
112120

113121

114-
def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
115-
stopThr=1e-9, verbose=False, log=False, **kwargs):
122+
def sinkhorn_unbalanced2(a, b, M, reg, alpha, method='sinkhorn',
123+
numItermax=1000, stopThr=1e-9, verbose=False,
124+
log=False, **kwargs):
116125
u"""
117126
Solve the entropic regularization unbalanced optimal transport problem and return the loss
118127
@@ -173,8 +182,8 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
173182
>>> a=[.5, .10]
174183
>>> b=[.5, .5]
175184
>>> M=[[0., 1.],[1., 0.]]
176-
>>> ot.sinkhorn2(a, b, M, 1., 1.)
177-
array([ 0.26894142])
185+
>>> ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.)
186+
array([0.31912866])
178187
179188
180189
@@ -199,23 +208,31 @@ def sinkhorn2(a, b, M, reg, alpha, method='sinkhorn', numItermax=1000,
199208

200209
if method.lower() == 'sinkhorn':
201210
def sink():
202-
return sinkhorn_knopp(a, b, M, reg, alpha, numItermax=numItermax,
203-
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
204-
else:
205-
warnings.warn('Unknown method using classic Sinkhorn Knopp')
211+
return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
212+
numItermax=numItermax,
213+
stopThr=stopThr, verbose=verbose,
214+
log=log, **kwargs)
215+
216+
elif method.lower() in ['sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']:
217+
warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
206218

207219
def sink():
208-
return sinkhorn_knopp(a, b, M, reg, alpha, **kwargs)
220+
return sinkhorn_knopp_unbalanced(a, b, M, reg, alpha,
221+
numItermax=numItermax,
222+
stopThr=stopThr, verbose=verbose,
223+
log=log, **kwargs)
224+
else:
225+
raise ValueError('Unknown method. Using classic Sinkhorn Knopp')
209226

210227
b = np.asarray(b, dtype=np.float64)
211228
if len(b.shape) < 2:
212-
b = b[None, :]
229+
b = b[:, None]
213230

214231
return sink()
215232

216233

217-
def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
218-
stopThr=1e-9, verbose=False, log=False, **kwargs):
234+
def sinkhorn_knopp_unbalanced(a, b, M, reg, alpha, numItermax=1000,
235+
stopThr=1e-9, verbose=False, log=False, **kwargs):
219236
"""
220237
Solve the entropic regularization unbalanced optimal transport problem and return the loss
221238
@@ -273,10 +290,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
273290
>>> a=[.5, .15]
274291
>>> b=[.5, .5]
275292
>>> M=[[0., 1.],[1., 0.]]
276-
>>> ot.sinkhorn(a, b, M, 1., 1.)
277-
array([[ 0.36552929, 0.13447071],
278-
[ 0.13447071, 0.36552929]])
279-
293+
>>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
294+
array([[0.52761554, 0.22392482],
295+
[0.10286295, 0.32257641]])
280296
281297
References
282298
----------
@@ -303,8 +319,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
303319
if len(b) == 0:
304320
b = np.ones(n_b, dtype=np.float64) / n_b
305321

306-
assert n_a == len(a) and n_b == len(b)
307-
if b.ndim > 1:
322+
if len(b.shape) > 1:
308323
n_hists = b.shape[1]
309324
else:
310325
n_hists = 0
@@ -315,8 +330,9 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
315330
# we assume that no distances are null except those of the diagonal of
316331
# distances
317332
if n_hists:
318-
u = np.ones((n_a, n_hists)) / n_a
333+
u = np.ones((n_a, 1)) / n_a
319334
v = np.ones((n_b, n_hists)) / n_b
335+
a = a.reshape(n_a, 1)
320336
else:
321337
u = np.ones(n_a) / n_a
322338
v = np.ones(n_b) / n_b
@@ -332,6 +348,7 @@ def sinkhorn_knopp(a, b, M, reg, alpha, numItermax=1000,
332348

333349
cpt = 0
334350
err = 1.
351+
335352
while (err > stopThr and cpt < numItermax):
336353
uprev = u
337354
vprev = v
@@ -473,7 +490,7 @@ def barycenter_unbalanced(A, M, reg, alpha, weights=None, numItermax=1000,
473490
or np.any(np.isinf(u)) or np.any(np.isinf(v))):
474491
# we have reached the machine precision
475492
# come back to previous solution and quit loop
476-
warnings.warn('Numerical errors at iteration', cpt)
493+
warnings.warn('Numerical errors at iteration %s' % cpt)
477494
u = uprev
478495
v = vprev
479496
break

test/test_unbalanced.py

Lines changed: 77 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def test_unbalanced_convergence(method):
2929
G, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon, alpha=alpha,
3030
stopThr=1e-10, method=method,
3131
log=True)
32-
32+
loss = ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
33+
method=method)
3334
# check fixed point equations
3435
fi = alpha / (alpha + epsilon)
3536
v_final = (b / K.T.dot(log["u"])) ** fi
@@ -40,6 +41,44 @@ def test_unbalanced_convergence(method):
4041
np.testing.assert_allclose(
4142
v_final, log["v"], atol=1e-05)
4243

44+
# check if sinkhorn_unbalanced2 returns the correct loss
45+
np.testing.assert_allclose((G * M).sum(), loss, atol=1e-5)
46+
47+
48+
@pytest.mark.parametrize("method", ["sinkhorn"])
49+
def test_unbalanced_multiple_inputs(method):
50+
# test generalized sinkhorn for unbalanced OT when b holds several histograms (one per column)
51+
n = 100
52+
rng = np.random.RandomState(42)
53+
54+
x = rng.randn(n, 2)
55+
a = ot.utils.unif(n)
56+
57+
# make dists unbalanced
58+
b = rng.rand(n, 2)
59+
60+
M = ot.dist(x, x)
61+
epsilon = 1.
62+
alpha = 1.
63+
K = np.exp(- M / epsilon)
64+
65+
loss, log = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg=epsilon,
66+
alpha=alpha,
67+
stopThr=1e-10, method=method,
68+
log=True)
69+
# check fixed point equations
70+
fi = alpha / (alpha + epsilon)
71+
v_final = (b / K.T.dot(log["u"])) ** fi
72+
73+
u_final = (a[:, None] / K.dot(log["v"])) ** fi
74+
75+
np.testing.assert_allclose(
76+
u_final, log["u"], atol=1e-05)
77+
np.testing.assert_allclose(
78+
v_final, log["v"], atol=1e-05)
79+
80+
assert len(loss) == b.shape[1]
81+
4382

4483
def test_unbalanced_barycenter():
4584
# test generalized sinkhorn for unbalanced OT barycenter
@@ -59,7 +98,6 @@ def test_unbalanced_barycenter():
5998
q, log = ot.unbalanced.barycenter_unbalanced(A, M, reg=epsilon, alpha=alpha,
6099
stopThr=1e-10,
61100
log=True)
62-
63101
# check fixed point equations
64102
fi = alpha / (alpha + epsilon)
65103
v_final = (q[:, None] / K.T.dot(log["u"])) ** fi
@@ -69,3 +107,40 @@ def test_unbalanced_barycenter():
69107
u_final, log["u"], atol=1e-05)
70108
np.testing.assert_allclose(
71109
v_final, log["v"], atol=1e-05)
110+
111+
112+
def test_implemented_methods():
113+
IMPLEMENTED_METHODS = ['sinkhorn']
114+
TO_BE_IMPLEMENTED_METHODS = ['sinkhorn_stabilized',
115+
'sinkhorn_epsilon_scaling']
116+
NOT_VALID_TOKENS = ['foo']
117+
# test that implemented methods run, not-yet-implemented methods warn, and invalid method names raise ValueError
118+
n = 3
119+
rng = np.random.RandomState(42)
120+
121+
x = rng.randn(n, 2)
122+
a = ot.utils.unif(n)
123+
124+
# make dists unbalanced
125+
b = ot.utils.unif(n) * 1.5
126+
127+
M = ot.dist(x, x)
128+
epsilon = 1.
129+
alpha = 1.
130+
for method in IMPLEMENTED_METHODS:
131+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
132+
method=method)
133+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
134+
method=method)
135+
with pytest.warns(UserWarning, match='not implemented'):
136+
for method in set(TO_BE_IMPLEMENTED_METHODS):
137+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
138+
method=method)
139+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
140+
method=method)
141+
with pytest.raises(ValueError):
142+
for method in set(NOT_VALID_TOKENS):
143+
ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha,
144+
method=method)
145+
ot.unbalanced.sinkhorn_unbalanced2(a, b, M, epsilon, alpha,
146+
method=method)

0 commit comments

Comments
 (0)