Skip to content
Merged
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Many other bugs and issues have been fixed and we want to thank all the contribu
- Fix issue backend for ot.sliced_wasserstein_sphere ot.sliced_wasserstein_sphere_unif (PR #471)
- Fix issue with ot.barycenter_stabilized when used with PyTorch tensors and log=True (PR #474)
- Fix `utils.cost_normalization` function issue to work with multiple backends (PR #472)
- Fix pression error on marginal sums and (Issue #429, PR #496)

#### New Contributors
* @kachayev made their first contribution in PR #462
Expand Down
27 changes: 19 additions & 8 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M):
return center_ot_dual(alpha, beta, a, b)


def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, check_marginals=True):
r"""Solves the Earth Movers distance problem and returns the OT matrix


Expand Down Expand Up @@ -259,6 +259,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.


Returns
-------
Expand Down Expand Up @@ -328,9 +332,10 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
"Dimension mismatch, check dimensions of M with a and b"

# ensure that same mass
np.testing.assert_almost_equal(a.sum(0),
b.sum(0), err_msg='a and b vector must have the same sum',
decimal=6)
if check_marginals:
np.testing.assert_almost_equal(a.sum(0),
b.sum(0), err_msg='a and b vector must have the same sum',
decimal=6)
b = b * a.sum() / b.sum()

asel = a != 0
Expand Down Expand Up @@ -368,7 +373,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):

def emd2(a, b, M, processes=1,
numItermax=100000, log=False, return_matrix=False,
center_dual=True, numThreads=1):
center_dual=True, numThreads=1, check_marginals=True):
r"""Solves the Earth Movers distance problem and returns the loss

.. math::
Expand Down Expand Up @@ -425,7 +430,11 @@ def emd2(a, b, M, processes=1,
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.

check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.


Returns
-------
W: float, array-like
Expand Down Expand Up @@ -492,8 +501,10 @@ def emd2(a, b, M, processes=1,
"Dimension mismatch, check dimensions of M with a and b"

# ensure that same mass
np.testing.assert_almost_equal(a.sum(0),
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
if check_marginals:
np.testing.assert_almost_equal(a.sum(0),
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum',
decimal=6)
b = b * a.sum(0) / b.sum(0,keepdims=True)

asel = a != 0
Expand Down
17 changes: 11 additions & 6 deletions ot/lp/solver_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def wasserstein_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, requ


def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log=False):
log=False, check_marginals=True):
r"""Solves the Earth Movers distance problem between 1d measures and returns
the OT matrix

Expand Down Expand Up @@ -181,6 +181,9 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
log: boolean, optional (default=False)
If True, returns a dictionary containing the cost.
Otherwise returns only the optimal transportation matrix.
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.

Returns
-------
Expand Down Expand Up @@ -235,11 +238,13 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]

# ensure that same mass
np.testing.assert_almost_equal(
nx.to_numpy(nx.sum(a, axis=0)),
nx.to_numpy(nx.sum(b, axis=0)),
err_msg='a and b vector must have the same sum'
)
if check_marginals:
np.testing.assert_almost_equal(
nx.to_numpy(nx.sum(a, axis=0)),
nx.to_numpy(nx.sum(b, axis=0)),
err_msg='a and b vector must have the same sum',
decimal=6
)
b = b * nx.sum(a) / nx.sum(b)

x_a_1d = nx.reshape(x_a, (-1,))
Expand Down