Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Hungarian solver from optax #598

Merged
merged 10 commits into from
Nov 19, 2024
Merged

add Hungarian solver from optax #598

merged 10 commits into from
Nov 19, 2024

Conversation

marcocuturi
Copy link
Contributor

Take advantage of google-deepmind/optax#1083 to include unregularized OT solver for point clouds of equal size.

Copy link

codecov bot commented Nov 18, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 88.23%. Comparing base (fd18299) to head (02193ef).
Report is 1 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #598      +/-   ##
==========================================
+ Coverage   87.90%   88.23%   +0.32%     
==========================================
  Files          73       74       +1     
  Lines        7770     7808      +38     
  Branches      556      556              
==========================================
+ Hits         6830     6889      +59     
+ Misses        799      778      -21     
  Partials      141      141              
Files with missing lines Coverage Δ
src/ott/geometry/costs.py 97.28% <100.00%> (+0.09%) ⬆️
src/ott/geometry/geometry.py 94.46% <ø> (ø)
src/ott/math/utils.py 91.30% <ø> (ø)
src/ott/tools/unreg.py 100.00% <100.00%> (ø)

... and 1 file with indirect coverage changes

---- 🚨 Try these New Features:

@marcocuturi marcocuturi requested a review from michalk8 November 18, 2024 22:42
@marcocuturi marcocuturi added the enhancement New feature or request label Nov 18, 2024
@@ -347,6 +347,34 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*aux_data)


@jtu.register_pytree_node_class
class EuclideanP(TICost):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add to the docs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok. Don't want to surprise the user because there's no Legendre transform, but can add.

the first geometry sends mass to point :math:`j` in the second.
"""
geom: geometry.Geometry
paired_indices: Optional[jnp.ndarray] = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove the optional and default none.

geom: geometry object
paired_indices: Array of shape ``[2, n]``, of :math:`n` pairs
of indices, for which the optimal transport assigns mass. Namely, for each
index :math:`0 <= k < n`, if one has
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use \leq



def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float:
"""Convenience wrapper on `hungarian` to get :term:`Wasserstein distance`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would remove the convenience wrapper part and rephrase as

Compute the :term:`Wasserstein distance` using the Hungarian algoritm.

or similar.



def hungarian(geom: geometry.Geometry) -> Tuple[jnp.ndarray, HungarianOutput]:
"""Solve assignment problem using Hungarian as implemented in optax.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:mod:`optax`

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also please add this to intersphinx_mapping in conf.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it there already?

intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "jax": ("https://jax.readthedocs.io/en/latest/", None),
    "jaxopt": ("https://jaxopt.github.io/stable", None),
    "lineax": ("https://docs.kidger.site/lineax/", None),
    "flax": ("https://flax.readthedocs.io/en/latest/", None),
    "optax": ("https://optax.readthedocs.io/en/latest/", None),
    "diffrax": ("https://docs.kidger.site/diffrax/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/", None),
    "pot": ("https://pythonot.github.io/", None),
    "matplotlib": ("https://matplotlib.org/stable/", None),
}

rng1, rng2 = jax.random.split(rng, 2)
x, y = gen_data(rng1, n, m, dim)
geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=.0005)
cost_hung, out_hung = unreg.hungarian(geom)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add jax.jit.

x, y = gen_data(rng1, n, m, dim)
geom = pointcloud.PointCloud(x, y, cost_fn=costs.EuclideanP(p=p))
cost_hung, _ = unreg.hungarian(geom)
w_p = unreg.wassdis_p(x, y, p)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add jax.jit.

def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, p: float = 2.0) -> float:
"""Convenience wrapper on `hungarian` to get :term:`Wasserstein distance`.

Uses :func:`~ott.tools.unreg.hungarian` to solve the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use :func:`hungarian`.

"""Solve assignment problem using Hungarian as implemented in optax.

Args:
geom: (square) geometry object.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would say

Geometry of shape ``[n, n]``.

src/ott/tools/unreg.py Show resolved Hide resolved
@michalk8 michalk8 merged commit b479e5f into main Nov 19, 2024
9 of 12 checks passed
@michalk8 michalk8 deleted the hungarian branch November 19, 2024 15:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants