-
Notifications
You must be signed in to change notification settings - Fork 82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add Hungarian solver from optax #598
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #598 +/- ##
==========================================
+ Coverage 87.90% 88.23% +0.32%
==========================================
Files 73 74 +1
Lines 7770 7808 +38
Branches 556 556
==========================================
+ Hits 6830 6889 +59
+ Misses 799 778 -21
Partials 141 141
|
@@ -347,6 +347,34 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 | |||
return cls(*aux_data) | |||
|
|||
|
|||
@jtu.register_pytree_node_class | |||
class EuclideanP(TICost): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add to the docs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok. Don't want to surprise the user because there's no Legendre transform, but can add.
the first geometry sends mass to point :math:`j` in the second. | ||
""" | ||
geom: geometry.Geometry | ||
paired_indices: Optional[jnp.ndarray] = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would remove the optional and default none.
src/ott/tools/unreg.py
Outdated
geom: geometry object | ||
paired_indices: Array of shape ``[2, n]``, of :math:`n` pairs | ||
of indices, for which the optimal transport assigns mass. Namely, for each | ||
index :math:`0 <= k < n`, if one has |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use \leq
src/ott/tools/unreg.py
Outdated
|
||
|
||
def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float: | ||
"""Convenience wrapper on `hungarian` to get :term:`Wasserstein distance`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would remove the convenience wrapper
part and rephrase as
Compute the :term:`Wasserstein distance` using the Hungarian algoritm.
or similar.
src/ott/tools/unreg.py
Outdated
|
||
|
||
def hungarian(geom: geometry.Geometry) -> Tuple[jnp.ndarray, HungarianOutput]: | ||
"""Solve assignment problem using Hungarian as implemented in optax. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:mod:`optax`
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also please add this to intersphinx_mapping
in conf.py
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isn't it there already?
intersphinx_mapping = {
"python": ("https://docs.python.org/3", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"jaxopt": ("https://jaxopt.github.io/stable", None),
"lineax": ("https://docs.kidger.site/lineax/", None),
"flax": ("https://flax.readthedocs.io/en/latest/", None),
"optax": ("https://optax.readthedocs.io/en/latest/", None),
"diffrax": ("https://docs.kidger.site/diffrax/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
"pot": ("https://pythonot.github.io/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
}
tests/tools/unreg_test.py
Outdated
rng1, rng2 = jax.random.split(rng, 2) | ||
x, y = gen_data(rng1, n, m, dim) | ||
geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=.0005) | ||
cost_hung, out_hung = unreg.hungarian(geom) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add jax.jit
.
tests/tools/unreg_test.py
Outdated
x, y = gen_data(rng1, n, m, dim) | ||
geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p=p)) | ||
cost_hung, _ = unreg.hungarian(geom) | ||
w_p = unreg.wassdis_p(x, y, p) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add jax.jit
.
src/ott/tools/unreg.py
Outdated
def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float: | ||
"""Convenience wrapper on `hungarian` to get :term:`Wasserstein distance`. | ||
|
||
Uses :func:`~ott.tools.unreg.hungarian` to solve the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use :func:`hungarian`.
src/ott/tools/unreg.py
Outdated
"""Solve assignment problem using Hungarian as implemented in optax. | ||
|
||
Args: | ||
geom: (square) geometry object. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would say
Geometry of shape ``[n, n]``.
Take advantage of google-deepmind/optax#1083 to include unregularized OT solver for point clouds of equal size.