From ed6ce20ba3087f526e6336396cc06ace19b649e9 Mon Sep 17 00:00:00 2001 From: arincbulgur Date: Thu, 15 Dec 2022 16:40:06 -0500 Subject: [PATCH 1/3] Pass warn argument downstream in sinkhorn2 method. --- ot/bregman.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index 4e1a25c1d..786264e89 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -320,15 +320,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if len(b.shape) < 2: if method.lower() == 'sinkhorn': res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) @@ -341,15 +344,18 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, if method.lower() == 'sinkhorn': return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_log': return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) elif method.lower() == 'sinkhorn_stabilized': return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) else: raise ValueError("Unknown method '%s'." % method) From 7dbd27ad429f73d1429e6fa1ec9fefff3dfdd216 Mon Sep 17 00:00:00 2001 From: arincbulgur Date: Thu, 15 Dec 2022 19:49:36 -0500 Subject: [PATCH 2/3] releases.md --- RELEASES.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 9cfdd352a..0e7d96f5c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -30,8 +30,9 @@ roughly 2^31) (PR #381) - Fixed an issue where the doc could not be built due to some changes in matplotlib's API (Issue #403, PR #402) - Replaced Numpy C Compiler with Setuptools C Compiler due to deprecation issues (Issue #408, PR #409) - Fixed weak optimal transport docstring (Issue #404, PR #410) -- Fixed error whith parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, +- Fixed error with parameter `log=True`for `SinkhornLpl1Transport` (Issue #412, PR #413) +- Fixed an issue about `warn` parameter in `sinkhorn2` (PR #417) ## 0.8.2 From 71a28f76a67c54e69db67d899c290e0a12ca2bd8 Mon Sep 17 00:00:00 2001 From: arincbulgur Date: Thu, 22 Dec 2022 22:59:08 -0500 Subject: [PATCH 3/3] Fix unittest. --- test/test_bregman.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 0f47c3f17..ce1564225 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -7,6 +7,7 @@ # # License: MIT License +import warnings from itertools import product import numpy as np @@ -58,7 +59,10 @@ def test_convergence_warning(method): with pytest.warns(UserWarning): ot.barycenter(A, M, 1, method=method, stopThr=0, numItermax=1) with pytest.warns(UserWarning): - ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1) + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=True) + with warnings.catch_warnings(): + warnings.simplefilter("error") + ot.sinkhorn2(a1, a2, M, 1, method=method, stopThr=0, numItermax=1, warn=False) def test_not_implemented_method():