Fix/lr linear apply #91
Conversation
Thanks Michal, I think this is a good fix. I agree that the test to check linearity was really buggy, and it's not surprising this arises when trying to jit compile things...
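For context, here is a minimal sketch (an assumption about the shape of the check, not the repository's exact code) of the kind of runtime linearity probe under discussion, and why it is fragile under `jit`:

```python
import jax.numpy as jnp

def looks_linear(fn, dim: int = 4, tol: float = 1e-6):
  """Hypothetical helper: probe fn for additivity on two test vectors."""
  x = jnp.arange(1.0, dim + 1.0)
  y = jnp.linspace(-1.0, 1.0, dim)
  # Brittle on two counts: a coincidental match on these probe points
  # misclassifies fn, and inside jit the comparison is a traced value
  # rather than a Python bool, so it cannot pick a branch statically
  # and forces a jax.lax.cond.
  return jnp.allclose(fn(x + y), fn(x) + fn(y), atol=tol)
```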
```diff
-Loss = Tuple[Tuple[LossTerm, LossTerm], Tuple[LossTerm, LossTerm]]
+class Loss(NamedTuple):
+  func: Callable[[jnp.ndarray], jnp.ndarray]
+  is_linear: bool
```
would it make sense to set it to False by default?
I want to keep this explicit so when new losses are added, one has to think about whether these terms are linear.
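For instance, with the new `Loss` tuple the squared-Euclidean GW loss could be annotated as below (a sketch based on the decomposition `(x - y)**2 = x**2 + y**2 - 2*x*y`; the factory name and return shape are illustrative):

```python
import jax.numpy as jnp
from typing import Callable, NamedTuple

class Loss(NamedTuple):
  func: Callable[[jnp.ndarray], jnp.ndarray]
  is_linear: bool

def make_square_loss():
  # (x - y)**2 = f1(x) + f2(y) - h1(x) * h2(y)
  f1 = Loss(lambda x: x ** 2, is_linear=False)  # quadratic, not linear
  f2 = Loss(lambda y: y ** 2, is_linear=False)  # quadratic, not linear
  h1 = Loss(lambda x: x, is_linear=True)        # identity, linear
  h2 = Loss(lambda y: 2.0 * y, is_linear=True)  # scaling, linear
  return (f1, f2), (h1, h2)
```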
Great, will merge. Will do the same changes for
Thanks Michal
* Refactor the way the linear function is determined in LR
* Be backwards compatible, use the linear heuristic
* Fix debug is_linear value, move test to FGW
Follow-up of #88
Requires knowing whether `fn` is linear at compile time for a more efficient implementation. I wasn't able to find a better solution that would maintain the `jax.lax.cond`, but I also slightly disliked the heuristic.

I've tried to figure out why the OOM exactly happens, as the inefficient branch shouldn't even be executed. I suspect it's either because `_apply_cost_to_vec` is inside a `vmap` (jax-ml/jax#9543 (comment)) or because there was no data dependence between the condition and the branches (jax-ml/jax#3103 (comment); I've tried different branching arguments and `cond` over `vmap` instead of `vmap` over `cond`, which didn't help).

Now, in the case of GW, the squared Euclidean/KL loss components are manually annotated with whether they are linear or not (I don't think this is too much hassle for the user, and it is more robust than the current check); a sketch of the resulting dispatch follows.
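This is roughly the branching structure described above (the helper names and bodies here are hypothetical stand-ins that treat the cost as a dense matrix; only the dispatch logic matters):

```python
import jax
import jax.numpy as jnp

def efficient_apply(cost, vec, fn):
  # Cheap path, valid when fn is elementwise-linear: fn(C) @ v == fn(C @ v).
  return fn(cost @ vec)

def materialized_apply(cost, vec, fn):
  # Expensive path: materialize fn(C) before applying it.
  return fn(cost) @ vec

def apply_cost_to_vec(cost, vec, fn, is_linear=None):
  if is_linear is not None:
    # Static Python branch, resolved at trace time: only one path is compiled.
    if is_linear:
      return efficient_apply(cost, vec, fn)
    return materialized_apply(cost, vec, fn)
  # Backwards-compatible default: runtime heuristic inside jax.lax.cond,
  # so both branches are traced even though only one ever runs.
  return jax.lax.cond(
      jnp.allclose(fn(vec + vec), fn(vec) + fn(vec)),  # crude linearity probe
      lambda v: efficient_apply(cost, v, fn),
      lambda v: materialized_apply(cost, v, fn),
      vec,
  )
```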
I've kept the `jax.lax.cond` as the default when `is_linear=None` for compatibility with `LRSinkhorn`, which will be fixed in a future PR.

Furthermore, the same issue occurs in `PointCloud` here, see also the snippet below. This would require annotating the costs in `ott.geometry.costs`, which I think should also be done in a future PR.

Below is a small testing snippet; the current memory test for GW has been modified to FGW and passes with even lower memory:
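(The snippet itself did not survive this copy; the following is a hedged reconstruction of what such a repro could look like, assuming the `PointCloud`/`apply_cost` API of the time. The sizes, the `online=True` flag, and the choice of `fn` are illustrative.)

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud

n, d = 50_000, 7
k0, k1 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k0, (n, d))
y = jax.random.normal(k1, (n, d))
vec = jnp.ones((n,))

geom = pointcloud.PointCloud(x, y, online=True)

# fn is linear, so only the efficient branch should ever run, yet the
# runtime check goes through jax.lax.cond and the branch materializing
# the full n-by-n cost is traced too, which is where the OOM comes from.
out = jax.jit(lambda v: geom.apply_cost(v, axis=0, fn=lambda t: 2.0 * t))(vec)
```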
`jax`/`jaxlib` versions: '0.3.13', '0.3.10'