
Fix/lr linear apply #91

Merged
merged 3 commits into ott-jax:main on Jun 30, 2022
Conversation

michalk8 (Collaborator):

Follow-up of #88

Require knowing at compile time whether `fn` is linear, to enable a more efficient implementation.
I wasn't able to find a better solution that would keep the `jax.lax.cond`, and I also slightly disliked the heuristic.
I've tried to figure out exactly why the OOM happens, since the inefficient branch shouldn't even be executed. I suspect it's either because `_apply_cost_to_vec` is inside a `vmap` (jax-ml/jax#9543 (comment)) or because there was no data dependence between the condition and the branches (jax-ml/jax#3103 (comment); I've tried different branching arguments and `cond` over `vmap` instead of `vmap` over `cond`, which didn't help).
In the case of GW, the components of the squared Euclidean and KL losses are now manually annotated as linear or not (I don't think this is too much hassle for the user, and it is more robust than the current check).
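
To illustrate the difference (a minimal sketch with hypothetical names, not ott's actual implementation): when `is_linear` is a static Python bool, a plain `if` means only the selected branch is ever traced, whereas `jax.lax.cond` traces both branches and, under `vmap`, is lowered to a `select` that can execute both:

import jax
import jax.numpy as jnp

def apply_cost_sketch(c1, c2, vec, fn, is_linear):
    # Low-rank cost: C = c1 @ c2.T, with c1 of shape (n, r), c2 of shape (m, r).
    if is_linear:
        # An elementwise linear fn is a scaling, fn(x) = a * x, so
        # fn(C) @ vec == c1 @ fn(c2.T @ vec) and no (n, m) matrix is formed.
        return c1 @ fn(c2.T @ vec)
    # Non-linear fn: materialize the full (n, m) cost matrix. With a
    # jax.lax.cond instead of the Python `if`, this branch would be traced
    # (and, under vmap, executed) even when the predicate is False, which
    # is the suspected source of the OOM.
    return fn(c1 @ c2.T) @ vec

key = jax.random.PRNGKey(0)
c1 = jax.random.normal(key, (5, 2))
c2 = jax.random.normal(key, (4, 2))
vec = jnp.ones((4,))
# fn(x) = 10 * x is linear, so the cheap path applies.
out = apply_cost_sketch(c1, c2, vec, lambda x: x * 10, is_linear=True)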

I've kept the `jax.lax.cond` as the default when `is_linear=None` for compatibility with `LRSinkhorn`, which will be fixed in a future PR.
Furthermore, the same issue occurs in `PointCloud`; see also the snippet below. Fixing it would require annotating the costs in `ott.geometry.costs`, which I think should also be done in a future PR.

Below is a small testing snippet; the existing memory test for GW has been modified to FGW and passes with an even lower memory budget:

from jax.config import config
config.update("jax_enable_x64", True)

import ott
import numpy as np
import jax.numpy as jnp

n = 300_000
d = 30
epsilon = 1e-2
np.random.seed(42)

x = jnp.asarray(np.random.normal(size=(n, d)))
geom_x = ott.geometry.pointcloud.PointCloud(x, epsilon=epsilon)
geom_x_lrc = geom_x.to_LRCGeometry()

# `geom_x.apply_cost` fails with:
# XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 1440000000048 bytes.
gt = geom_x.vec_apply_cost(x[:, 0], fn=lambda x: x * 10)
# `geom_x_lrc.apply_cost` with `is_linear=None` fails with:
# XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory allocating 720000000072 bytes.
pred = geom_x_lrc.apply_cost(x[:, 0], fn=lambda x: x * 10, is_linear=True)
assert jnp.allclose(gt, pred)

jax/jaxlib versions: '0.3.13', '0.3.10'
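
For scale, those allocation sizes are almost exactly the size of dense n x n float64 cost matrices (a back-of-the-envelope check; the 72- and 48-byte remainders are presumably allocator overhead):

n = 300_000
print(n * n * 8)      # 720000000000:  one dense n x n float64 matrix
print(2 * n * n * 8)  # 1440000000000: two such matrices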

@marcocuturi (Contributor) left a comment:
Thanks Michal, I think this is a good fix. I agree that the test to check linearity was really buggy, and it's not surprising this arises when trying to jit compile things...

Loss = Tuple[Tuple[LossTerm, LossTerm], Tuple[LossTerm, LossTerm]]  # before

class Loss(NamedTuple):  # after
    func: Callable[[jnp.ndarray], jnp.ndarray]
    is_linear: bool
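
As an illustration of the annotation (a sketch; the exact decomposition used in ott may differ), the squared Euclidean GW loss factors as L(x, y) = f1(x) + f2(y) - h1(x) * h2(y), with quadratic f-terms and linear h-terms:

from typing import Callable, NamedTuple
import jax.numpy as jnp

class Loss(NamedTuple):
    func: Callable[[jnp.ndarray], jnp.ndarray]
    is_linear: bool

# The f-terms are quadratic, the h-terms are linear in their argument.
f1 = Loss(func=lambda x: x ** 2, is_linear=False)
f2 = Loss(func=lambda y: y ** 2, is_linear=False)
h1 = Loss(func=lambda x: x, is_linear=True)
h2 = Loss(func=lambda y: 2.0 * y, is_linear=True)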
@marcocuturi (Contributor):
would it make sense to set it to False by default?

@michalk8 (Collaborator, Author):
I want to keep this explicit so that when new losses are added, one has to think about whether these terms are linear.
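
A quick illustration of the trade-off (using the Loss definition above): with no default, omitting the annotation fails immediately rather than silently picking a branch:

quad = Loss(func=lambda x: x ** 2, is_linear=False)  # OK: annotation is explicit
# Loss(func=lambda x: x)  # TypeError: missing 1 required positional argument: 'is_linear'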

@michalk8 (Collaborator, Author):

> Thanks Michal, I think this is a good fix. I agree that the test to check linearity was really buggy, and it's not surprising this arises when trying to jit compile things...

Great, will merge. I'll make the same changes for PointCloud and enable this for LRSinkhorn in a future PR.

michalk8 self-assigned this Jun 30, 2022
michalk8 added the "bug" label (Something isn't working) Jun 30, 2022
michalk8 merged commit 43ce3f8 into ott-jax:main Jun 30, 2022
michalk8 deleted the fix/lr-linear-apply branch June 30, 2022 18:11
@marcocuturi (Contributor) left a comment:

Thanks Michal

michalk8 added a commit that referenced this pull request Jun 27, 2024
* Refactor the way the linear function is determined in LR

* Be backwards compatible, use linear heuristic

* Fix debug is_linear value, move test to FGW