From c9a7735c1df39ad8ed03cfca0e8fc85cc9b70c72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 6 Dec 2022 14:50:29 +0100 Subject: [PATCH 1/5] correct bug in DA l1lp with log --- ot/da.py | 13 ++++++++++--- test/test_da.py | 19 +++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/ot/da.py b/ot/da.py index 0b9737e0e..083663cca 100644 --- a/ot/da.py +++ b/ot/da.py @@ -126,8 +126,12 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, W = nx.zeros(M.shape, type_as=M) for cpt in range(numItermax): Mreg = M + eta * W - transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, - stopThr=stopInnerThr) + if log: + transp, log = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr, log=True) + else: + transp = sinkhorn(a, b, Mreg, reg, numItermax=numInnerItermax, + stopThr=stopInnerThr) # the transport has been computed. Check if classes are really # separated W = nx.ones(M.shape, type_as=M) @@ -136,7 +140,10 @@ def sinkhorn_lpl1_mm(a, labels_a, b, M, reg, eta=0.1, numItermax=10, majs = p * ((majs + epsilon) ** (p - 1)) W[indices_labels[i]] = majs - return transp + if log: + return transp, log + else: + return transp def sinkhorn_l1l2_gl(a, labels_a, b, M, reg, eta=0.1, numItermax=10, diff --git a/test/test_da.py b/test/test_da.py index 4bf0ab11f..c9b1b5e97 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -42,6 +42,25 @@ def test_class_jax_tf(): otda.fit(Xs=Xs, ys=ys, Xt=Xt) +@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.EMDLaplaceTransport]) +def test_log_da(nx, class_to_test): + + ns = 150 + nt = 200 + + Xs, ys = make_data_classif('3gauss', ns) + Xt, yt = make_data_classif('3gauss2', nt) + + Xs, ys, Xt, yt = nx.from_numpy(Xs, ys, Xt, yt) + + otda = class_to_test(log=True) + + # test its computed + otda.fit(Xs=Xs, ys=ys, Xt=Xt) + assert hasattr(otda, "cost_") + assert hasattr(otda, "coupling_") + + @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_sinkhorn_lpl1_transport_class(nx): From d2fd0a20e76de1892009d444d1a81896f1450b2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 6 Dec 2022 15:01:56 +0100 Subject: [PATCH 2/5] better tests and speedup with smaller dataset size --- RELEASES.md | 3 ++- test/test_da.py | 51 ++++++++++++++++++++++++------------------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 68487e8c1..3bd84c1c1 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -27,7 +27,8 @@ roughly 2^31) (PR #381) - Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402) - Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409) - Fixed weak optimal transport docstring (Issue #404, PR #410) - +- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, +PR #413) ## 0.8.2 diff --git a/test/test_da.py b/test/test_da.py index c9b1b5e97..fe80e44f8 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -42,11 +42,11 @@ def test_class_jax_tf(): otda.fit(Xs=Xs, ys=ys, Xt=Xt) -@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.EMDLaplaceTransport]) +@pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport]) def test_log_da(nx, class_to_test): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -57,8 +57,7 @@ def test_log_da(nx, class_to_test): # test its computed otda.fit(Xs=Xs, ys=ys, Xt=Xt) - assert hasattr(otda, "cost_") - assert hasattr(otda, "coupling_") + assert hasattr(otda, "log_") @pytest.skip_backend("jax") @@ -67,8 +66,8 @@ def test_sinkhorn_lpl1_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -155,7 +154,7 @@ def test_sinkhorn_l1l2_transport_class(nx): """ ns = 50 - nt = 100 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -249,8 +248,8 @@ def test_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -342,8 +341,8 @@ def test_unbalanced_sinkhorn_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -421,8 +420,8 @@ def test_emd_transport_class(nx): """test_sinkhorn_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -577,8 +576,8 @@ def test_mapping_transport_class_specific_seed(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_linear_mapping(nx): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -598,8 +597,8 @@ def test_linear_mapping(nx): @pytest.skip_backend("jax") @pytest.skip_backend("tf") def test_linear_mapping_class(nx): - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) @@ -628,9 +627,9 @@ def test_jcpot_transport_class(nx): """test_jcpot_transport """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 Xs1, ys1 = make_data_classif('3gauss', ns1) Xs2, ys2 = make_data_classif('3gauss', ns2) @@ -700,9 +699,9 @@ def test_jcpot_barycenter(nx): """test_jcpot_barycenter """ - ns1 = 150 - ns2 = 150 - nt = 200 + ns1 = 50 + ns2 = 50 + nt = 50 sigma = 0.1 np.random.seed(1985) @@ -732,8 +731,8 @@ def test_jcpot_barycenter(nx): def test_emd_laplace_class(nx): """test_emd_laplace_transport """ - ns = 150 - nt = 200 + ns = 50 + nt = 50 Xs, ys = make_data_classif('3gauss', ns) Xt, yt = make_data_classif('3gauss2', nt) From 9a4293e5f2fe6cae4db21c08d719a669ae6a0b9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 6 Dec 2022 16:26:38 +0100 Subject: [PATCH 3/5] remove jax for log test --- test/test_da.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_da.py b/test/test_da.py index fe80e44f8..820017791 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -41,7 +41,7 @@ def test_class_jax_tf(): with pytest.raises(TypeError): otda.fit(Xs=Xs, ys=ys, Xt=Xt) - +@pytest.skip_backend("jax") @pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport]) def test_log_da(nx, class_to_test): From 4f2c92f81be53c0c02fcc655326f24bbed2a74c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 6 Dec 2022 17:02:41 +0100 Subject: [PATCH 4/5] remove trndorflow for log test --- test/test_da.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_da.py b/test/test_da.py index 820017791..6b34f5acd 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -42,6 +42,7 @@ def test_class_jax_tf(): otda.fit(Xs=Xs, ys=ys, Xt=Xt) @pytest.skip_backend("jax") +@pytest.skip_backend("tf") @pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport]) def test_log_da(nx, class_to_test): From c5c6380b1f574d4e62d236eddd19e33d413f4ac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 6 Dec 2022 17:05:58 +0100 Subject: [PATCH 5/5] pep8! --- test/test_da.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_da.py b/test/test_da.py index 6b34f5acd..138936f19 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -41,6 +41,7 @@ def test_class_jax_tf(): with pytest.raises(TypeError): otda.fit(Xs=Xs, ys=ys, Xt=Xt) + @pytest.skip_backend("jax") @pytest.skip_backend("tf") @pytest.mark.parametrize("class_to_test", [ot.da.EMDTransport, ot.da.SinkhornTransport, ot.da.SinkhornLpl1Transport, ot.da.SinkhornL1l2Transport, ot.da.SinkhornL1l2Transport])