Skip to content

Commit 896b235

Browse files
committed
Silence std::cerr calls inside of python
add use_cerr flag to backend add use_cerr to lp.__init__ and passthrough to c++ add use_cerr defs to pyx fix call to use_cerr to update from ot.backend def use_cerr in emd.h emd wrapper cerr cpp _use_cerr releases.md
1 parent d6bf10d commit 896b235

File tree

8 files changed

+43
-24
lines changed

8 files changed

+43
-24
lines changed

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
1212
(Issue #371, PR #373)
13+
- Fixed an issue where hitting iteration limits would be reported to stderr by std:cerr regardless of Python's stderr stream status.
1314

1415
## 0.8.2
1516

ot/backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@
127127
tf = False
128128
tf_type = float
129129

130+
def get_cerr():
131+
return use_cerr
132+
use_cerr = False
130133

131134
str_type_error = "All array should be from the same type/backend. Current types are : {}"
132135

ot/lp/EMD.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ enum ProblemType {
2828
MAX_ITER_REACHED
2929
};
3030

31-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
32-
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
31+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, bool use_cerr);
32+
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads, bool use_cerr);
3333

3434

3535

ot/lp/EMD_wrapper.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020

2121

2222
int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
23-
double* alpha, double* beta, double *cost, int maxIter) {
23+
double* alpha, double* beta, double *cost, int maxIter,
24+
bool use_cerr=true) {
2425
// beware M and C are stored in row major C style!!!
2526

2627
using namespace lemon;
@@ -54,7 +55,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
5455
std::vector<int> indI(n), indJ(m);
5556
std::vector<double> weights1(n), weights2(m);
5657
Digraph di(n, m);
57-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter);
58+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, use_cerr);
5859

5960
// Set supply and demand, don't account for 0 values (faster)
6061

@@ -122,7 +123,8 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
122123

123124

124125
int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
125-
double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
126+
double* alpha, double* beta, double *cost, int maxIter, int numThreads,
127+
bool use_cerr=true) {
126128
// beware M and C are stored in row major C style!!!
127129

128130
using namespace lemon_omp;
@@ -156,7 +158,7 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
156158
std::vector<int> indI(n), indJ(m);
157159
std::vector<double> weights1(n), weights2(m);
158160
Digraph di(n, m);
159-
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads);
161+
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(di, true, n+m, ((int64_t)n)*((int64_t)m), maxIter, numThreads, use_cerr);
160162

161163
// Set supply and demand, don't account for 0 values (faster)
162164

ot/lp/__init__.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from ..utils import dist, list_to_array
2626
from ..utils import parmap
27-
from ..backend import get_backend
27+
from ..backend import get_backend, get_cerr
2828

2929
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', ' emd_1d_sorted',
3030
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter']
@@ -330,7 +330,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
330330

331331
numThreads = check_number_threads(numThreads)
332332

333-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
333+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads,
334+
use_cerr=get_cerr())
334335

335336
if center_dual:
336337
u, v = center_ot_dual(u, v, a, b)
@@ -489,7 +490,8 @@ def emd2(a, b, M, processes=1,
489490
def f(b):
490491
bsel = b != 0
491492

492-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
493+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads,
494+
use_cerr=get_cerr())
493495

494496
if center_dual:
495497
u, v = center_ot_dual(u, v, a, b)
@@ -521,7 +523,8 @@ def f(b):
521523
else:
522524
def f(b):
523525
bsel = b != 0
524-
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads)
526+
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads,
527+
use_cerr=get_cerr())
525528

526529
if center_dual:
527530
u, v = center_ot_dual(u, v, a, b)

ot/lp/emd_wrap.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Cython linker with C solver
99

1010
import numpy as np
1111
cimport numpy as np
12-
12+
from libcpp cimport bool as bool_t
1313
from ..utils import dist
1414

1515
cimport cython
@@ -19,8 +19,8 @@ import warnings
1919

2020

2121
cdef extern from "EMD.h":
22-
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
23-
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
22+
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, bool_t use_cerr) nogil
23+
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads, bool_t use_cerr) nogil
2424
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
2525

2626

@@ -39,7 +39,7 @@ def check_result(result_code):
3939

4040
@cython.boundscheck(False)
4141
@cython.wraparound(False)
42-
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
42+
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads, bool_t use_cerr=False):
4343
"""
4444
Solves the Earth Movers distance problem and returns the optimal transport matrix
4545
@@ -111,9 +111,9 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
111111
# calling the function
112112
with nogil:
113113
if numThreads == 1:
114-
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter)
114+
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, use_cerr)
115115
else:
116-
result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
116+
result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads, use_cerr)
117117
return G, cost, alpha, beta, result_code
118118

119119

ot/lp/network_simplex_simple.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,12 +233,13 @@ namespace lemon {
233233
/// mixed order in the internal data structure.
234234
/// In special cases, it could lead to better overall performance,
235235
/// but it is usually slower. Therefore it is disabled by default.
236-
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
236+
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters, bool use_cerr=true) :
237237
_graph(graph), //_arc_id(graph),
238238
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
239239
MAX(std::numeric_limits<Value>::max()),
240240
INF(std::numeric_limits<Value>::has_infinity ?
241-
std::numeric_limits<Value>::infinity() : MAX)
241+
std::numeric_limits<Value>::infinity() : MAX),
242+
_use_cerr(use_cerr)
242243
{
243244
// Reset data structures
244245
reset();
@@ -292,6 +293,7 @@ namespace lemon {
292293

293294

294295
private:
296+
bool _use_cerr;
295297

296298
size_t max_iter;
297299
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
@@ -334,6 +336,7 @@ namespace lemon {
334336
IntVector _source; // keep nodes as integers
335337
IntVector _target;
336338
bool _arc_mixing;
339+
337340
public:
338341
// Node and arc data
339342
CostVector _cost;
@@ -1433,8 +1436,11 @@ namespace lemon {
14331436
while (pivot.findEnteringArc()) {
14341437
if(max_iter > 0 && ++iter_number>=max_iter&&max_iter>0){
14351438
char errMess[1000];
1436-
sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
1437-
std::cerr << errMess;
1439+
1440+
if(_use_cerr){
1441+
sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
1442+
std::cerr << errMess;
1443+
}
14381444
retVal = MAX_ITER_REACHED;
14391445
break;
14401446
}

ot/lp/network_simplex_simple_omp.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,13 @@ namespace lemon_omp {
244244
/// mixed order in the internal data structure.
245245
/// In special cases, it could lead to better overall performance,
246246
/// but it is usually slower. Therefore it is disabled by default.
247-
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
247+
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1, bool use_cerr=true) :
248248
_graph(graph), //_arc_id(graph),
249249
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
250250
MAX(std::numeric_limits<Value>::max()),
251251
INF(std::numeric_limits<Value>::has_infinity ?
252-
std::numeric_limits<Value>::infinity() : MAX)
252+
std::numeric_limits<Value>::infinity() : MAX),
253+
_use_cerr(use_cerr)
253254
{
254255
// Reset data structures
255256
reset();
@@ -317,6 +318,7 @@ namespace lemon_omp {
317318

318319

319320
private:
321+
bool _use_cerr;
320322
size_t max_iter;
321323
int num_threads;
322324
TEMPLATE_DIGRAPH_TYPEDEFS(GR);
@@ -1611,8 +1613,10 @@ namespace lemon_omp {
16111613

16121614
} else {
16131615
char errMess[1000];
1614-
sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
1615-
std::cerr << errMess;
1616+
if(_use_cerr){
1617+
sprintf( errMess, "RESULT MIGHT BE INACURATE\nMax number of iteration reached, currently \%d. Sometimes iterations go on in cycle even though the solution has been reached, to check if it's the case here have a look at the minimal reduced cost. If it is very close to machine precision, you might actually have the correct solution, if not try setting the maximum number of iterations a bit higher\n",iter_number );
1618+
std::cerr << errMess;
1619+
}
16161620
retVal = MAX_ITER_REACHED;
16171621
break;
16181622
}

0 commit comments

Comments
 (0)