Skip to content
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ The contributors to this library are:
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)

## Acknowledgments

Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Fixed an issue where pointers would overflow in the EMD solver, returning an
incomplete transport plan above a certain size (slightly above 46k, its square being
roughly 2^31) (PR #381)
- Error raised when mass mismatch in emd2 (PR #386)


## 0.8.2
Expand Down
9 changes: 9 additions & 0 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
If this behaviour is unwanted, please make sure to provide a
floating point input.

.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.

Uses the algorithm proposed in :ref:`[1] <references-emd>`.

Parameters
Expand Down Expand Up @@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1,
If this behaviour is unwanted, please make sure to provide a
floating point input.

.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.

Uses the algorithm proposed in :ref:`[1] <references-emd2>`.

Parameters
Expand Down Expand Up @@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1,
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"

# ensure that same mass
np.testing.assert_almost_equal(a.sum(0),
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
b = b * a.sum(0) / b.sum(0,keepdims=True)

asel = a != 0

numThreads = check_number_threads(numThreads)
Expand Down
3 changes: 3 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():

np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)

# test emd and emd2 for mass mismatch
a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)


def test_emd_backends(nx):
Expand Down