-
Notifications
You must be signed in to change notification settings - Fork 528
[MRG] Fix ordering #139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Fix ordering #139
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch ! I just have a minor comment. Also I am not sure if it is better to have two different tests for the variants with and without weights, but this is not very important I guess.
ot/lp/__init__.py
Outdated
@@ -656,7 +656,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, | |||
perm_a = np.argsort(x_a_1d) | |||
perm_b = np.argsort(x_b_1d) | |||
|
|||
G_sorted, indices, cost = emd_1d_sorted(a, b, | |||
G_sorted, indices, cost = emd_1d_sorted(a[perm_a.flatten()], b[perm_b.flatten()], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess you do not have to flatten the permutation indices since they are computed from 1d arrays, or do you ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you explain why you flatten here?
Apart from that, LGTM!
test/test_ot.py
Outdated
np.testing.assert_allclose(wass, wass1d_emd2) | ||
|
||
# check loss is similar to scipy's implementation for Euclidean metric | ||
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You forgot the weights here, which is probably why the test fails at the moment
Errr, sorry about that. Tbh I ran the asserts in the first comment but not
the test in the PR :)
…On Thu, 2 Apr 2020, 07:56 Romain Tavenard, ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In test/test_ot.py
<#139 (comment)>:
> +
+ M = ot.dist(u, v, metric='sqeuclidean')
+
+ G, log = ot.emd(w_u, w_v, M, log=True)
+ wass = log["cost"]
+ G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
+ wass1d = log["cost"]
+ wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
+ wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)
+
+ # check loss is similar
+ np.testing.assert_allclose(wass, wass1d)
+ np.testing.assert_allclose(wass, wass1d_emd2)
+
+ # check loss is similar to scipy's implementation for Euclidean metric
+ wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
You forgot the weights here, which is probably why the test fails at the
moment
—
You are receiving this because you authored the thread.
Reply to this email directly, view it on GitHub
<#139 (review)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AEYGFZ43WKD5O4NXGICOKNLRKQZKLANCNFSM4LYZNRDA>
.
|
Hello @AdrienCorenflos, Thank you for finding the bug and the PR. In addition to the comments from @rtavenar be careful to check pep8 on test_ot.py or else the tests will fail. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM ! Thanks a lot for this bugfix !
Fixes #138