Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@

- Fixed an issue where we could not ask TorchBackend to place a random tensor on GPU
(Issue #371, PR #373)
- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status.
- Fixed an issue where the Sinkhorn solver assumed a symmetric cost matrix (Issue #374, PR #375)
- Fixed an issue where hitting iteration limits would be reported to stderr by std::cerr regardless of Python's stderr stream status (PR #377)
- Fixed an issue where the metric argument in ot.dist did not allow a callable parameter (Issue #378, PR #379)
- Fixed an issue where the max number of iterations in ot.emd was not allowed to go beyond 2^31 (PR #380)


## 0.8.2
Expand Down
5 changes: 3 additions & 2 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <iostream>
#include <vector>
#include <cstdint>

typedef unsigned int node_id_type;

Expand All @@ -28,8 +29,8 @@ enum ProblemType {
MAX_ITER_REACHED
};

int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads);
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);



Expand Down
4 changes: 2 additions & 2 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter) {
double* alpha, double* beta, double *cost, uint64_t maxIter) {
// beware M and C are stored in row major C style!!!

using namespace lemon;
Expand Down Expand Up @@ -122,7 +122,7 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,


int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, int maxIter, int numThreads) {
double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!

using namespace lemon_omp;
Expand Down
9 changes: 5 additions & 4 deletions ot/lp/emd_wrap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ from ..utils import dist

cimport cython
cimport libc.math as math
from libc.stdint cimport uint64_t

import warnings


cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter) nogil
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, int maxIter, int numThreads) nogil
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED


Expand All @@ -39,7 +40,7 @@ def check_result(result_code):

@cython.boundscheck(False)
@cython.wraparound(False)
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, int numThreads):
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, uint64_t max_iter, int numThreads):
"""
Solves the Earth Movers distance problem and returns the optimal transport matrix

Expand Down Expand Up @@ -75,7 +76,7 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
target histogram
M : (ns,nt) numpy.ndarray, float64
loss matrix
max_iter : int
max_iter : uint64_t
The maximum number of iterations before stopping the optimization
algorithm if it has not converged.

Expand Down
8 changes: 4 additions & 4 deletions ot/lp/network_simplex_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ namespace lemon {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters) :
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
Expand All @@ -242,7 +242,7 @@ namespace lemon {
{
// Reset data structures
reset();
max_iter=maxiters;
max_iter = maxiters;
}

/// The type of the flow amounts, capacity bounds and supply values
Expand Down Expand Up @@ -293,7 +293,7 @@ namespace lemon {

private:

size_t max_iter;
uint64_t max_iter;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);

typedef std::vector<int> IntVector;
Expand Down Expand Up @@ -1427,7 +1427,7 @@ namespace lemon {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;

size_t iter_number=0;
uint64_t iter_number = 0;
//pivot.setDantzig(true);
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
Expand Down
6 changes: 3 additions & 3 deletions ot/lp/network_simplex_simple_omp.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ namespace lemon_omp {
/// mixed order in the internal data structure.
/// In special cases, it could lead to better overall performance,
/// but it is usually slower. Therefore it is disabled by default.
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, size_t maxiters = 0, int numThreads=-1) :
NetworkSimplexSimple(const GR& graph, bool arc_mixing, int nbnodes, ArcsType nb_arcs, uint64_t maxiters = 0, int numThreads=-1) :
_graph(graph), //_arc_id(graph),
_arc_mixing(arc_mixing), _init_nb_nodes(nbnodes), _init_nb_arcs(nb_arcs),
MAX(std::numeric_limits<Value>::max()),
Expand Down Expand Up @@ -317,7 +317,7 @@ namespace lemon_omp {


private:
size_t max_iter;
uint64_t max_iter;
int num_threads;
TEMPLATE_DIGRAPH_TYPEDEFS(GR);

Expand Down Expand Up @@ -1563,7 +1563,7 @@ namespace lemon_omp {
// Perform heuristic initial pivots
if (!initialPivots()) return UNBOUNDED;

size_t iter_number = 0;
uint64_t iter_number = 0;
// Execute the Network Simplex algorithm
while (pivot.findEnteringArc()) {
if ((++iter_number <= max_iter&&max_iter > 0) || max_iter<=0) {
Expand Down
38 changes: 20 additions & 18 deletions test/test_unbalanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,26 +295,27 @@ def test_mm_convergence(nx):
x = rng.randn(n, 2)
rng = np.random.RandomState(75)
y = rng.randn(n, 2)
a = ot.utils.unif(n)
b = ot.utils.unif(n)
a_np = ot.utils.unif(n)
b_np = ot.utils.unif(n)

M = ot.dist(x, y)
M = M / M.max()
reg_m = 100
a, b, M = nx.from_numpy(a, b, M)
a, b, M = nx.from_numpy(a_np, b_np, M)

G_kl, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl',
verbose=True, log=True)
loss_kl = nx.to_numpy(ot.unbalanced.mm_unbalanced2(
a, b, M, reg_m, div='kl', verbose=True))
verbose=False, log=True)
loss_kl = nx.to_numpy(
ot.unbalanced.mm_unbalanced2(a, b, M, reg_m, div='kl', verbose=True)
)
G_l2, _ = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2',
verbose=False, log=True)

# check if the marginals come close to the true ones when large reg
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a, atol=1e-03)
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b, atol=1e-03)
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a, atol=1e-03)
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b, atol=1e-03)
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 1), a_np, atol=1e-03)
np.testing.assert_allclose(np.sum(nx.to_numpy(G_kl), 0), b_np, atol=1e-03)
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 1), a_np, atol=1e-03)
np.testing.assert_allclose(np.sum(nx.to_numpy(G_l2), 0), b_np, atol=1e-03)

# check if mm_unbalanced2 returns the correct loss
np.testing.assert_allclose(nx.to_numpy(nx.sum(G_kl * M)), loss_kl,
Expand All @@ -324,15 +325,16 @@ def test_mm_convergence(nx):
a_np, b_np = np.array([]), np.array([])
a, b = nx.from_numpy(a_np, b_np)

G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl')
G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2')
np.testing.assert_allclose(G_kl_null, G_kl)
np.testing.assert_allclose(G_l2_null, G_l2)
G_kl_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', verbose=False)
G_l2_null = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', verbose=False)
np.testing.assert_allclose(nx.to_numpy(G_kl_null), nx.to_numpy(G_kl))
np.testing.assert_allclose(nx.to_numpy(G_l2_null), nx.to_numpy(G_l2))

# test when G0 is given
G0 = ot.emd(a, b, M)
G0_np = nx.to_numpy(G0)
reg_m = 10000
G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0)
G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0)
np.testing.assert_allclose(G0, G_kl, atol=1e-05)
np.testing.assert_allclose(G0, G_l2, atol=1e-05)
G_kl = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='kl', G0=G0, verbose=False)
G_l2 = ot.unbalanced.mm_unbalanced(a, b, M, reg_m=reg_m, div='l2', G0=G0, verbose=False)
np.testing.assert_allclose(G0_np, nx.to_numpy(G_kl), atol=1e-05)
np.testing.assert_allclose(G0_np, nx.to_numpy(G_l2), atol=1e-05)