Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ roughly 2^31) (PR #381)
- Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402)
- Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409)
- Fixed weak optimal transport docstring (Issue #404, PR #410)

- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412,
PR #413)

## 0.8.2

Expand Down
13 changes: 10 additions & 3 deletions ot/da.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
W = nx.zeros(M.shape, type_as=M)
for cpt in range(numItermax):
Mreg = M + eta * W
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
stopThr=stopInnerThr)
if log:
transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
stopThr=stopInnerThr, log=True)
else:
transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax,
stopThr=stopInnerThr)
# the transport has been computed. Check if classes are really
# separated
W = nx.ones(M.shape, type_as=M)
Expand All @@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
majs = p * ((majs + epsilon) ** (p - 1))
W[indices_labels[i]] = majs

return transp
if log:
return transp, log
else:
return transp


def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10,
Expand Down
62 changes: 41 additions & 21 deletions test/test_da.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,34 @@ def test_class_jax_tf():
otda.fit(Xs=Xs, ys=ys, Xt=Xt)


@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])
def test_log_da(nx, class_to_test):

ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)

Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt)

otda = class_to_test(log=True)

# test its computed
otda.fit(Xs=Xs, ys=ys, Xt=Xt)
assert hasattr(otda, "log_")


@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_sinkhorn_lpl1_transport_class(nx):
"""test_sinkhorn_transport
"""

ns = 150
nt = 200
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand Down Expand Up @@ -136,7 +156,7 @@ def test_sinkhorn_l1l2_transport_class(nx):
"""

ns = 50
nt = 100
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand Down Expand Up @@ -230,8 +250,8 @@ def test_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""

ns = 150
nt = 200
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand Down Expand Up @@ -323,8 +343,8 @@ def test_unbalanced_sinkhorn_transport_class(nx):
"""test_sinkhorn_transport
"""

ns = 150
nt = 200
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand Down Expand Up @@ -402,8 +422,8 @@ def test_emd_transport_class(nx):
"""test_sinkhorn_transport
"""

ns = 150
nt = 200
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand Down Expand Up @@ -558,8 +578,8 @@ def test_mapping_transport_class_specific_seed(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_linear_mapping(nx):
ns = 150
nt = 200
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand All @@ -579,8 +599,8 @@ def test_linear_mapping(nx):
@pytest.skip_backend("jax")
@pytest.skip_backend("tf")
def test_linear_mapping_class(nx):
ns = 150
nt = 200
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand Down Expand Up @@ -609,9 +629,9 @@ def test_jcpot_transport_class(nx):
"""test_jcpot_transport
"""

ns1 = 150
ns2 = 150
nt = 200
ns1 = 50
ns2 = 50
nt = 50

Xs1, ys1 = make_data_classif('3gauss', ns1)
Xs2, ys2 = make_data_classif('3gauss', ns2)
Expand Down Expand Up @@ -681,9 +701,9 @@ def test_jcpot_barycenter(nx):
"""test_jcpot_barycenter
"""

ns1 = 150
ns2 = 150
nt = 200
ns1 = 50
ns2 = 50
nt = 50

sigma = 0.1
np.random.seed(1985)
Expand Down Expand Up @@ -713,8 +733,8 @@ def test_jcpot_barycenter(nx):
def test_emd_laplace_class(nx):
"""test_emd_laplace_transport
"""
ns = 150
nt = 200
ns = 50
nt = 50

Xs, ys = make_data_classif('3gauss', ns)
Xt, yt = make_data_classif('3gauss2', nt)
Expand Down