Address google/jax#19885 for numpyro. #1743
numpyro/distributions/continuous.py
Outdated
@@ -1860,7 +1860,7 @@ def _batch_capacitance_tril(W, D):
     Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
     K = jnp.matmul(Wt_Dinv, W)
     # could be inefficient
-    return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1])))
+    return jnp.linalg.cholesky(jnp.subtract(K, -jnp.identity(K.shape[-1])))
This is interesting. I couldn't come up with a good solution to compute K + eye. Switching to subtract is a bit unfortunate to me.
Yeah, it's not great. I'll have a think about whether `K.at[..., <diag indices>].add(1)` could do the job.
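For reference, the three ways of forming `K + eye` discussed so far can be compared on a small batched array. This is only a sketch: the array `K` below is made-up toy data, and `idx` is one possible way of spelling out `<diag indices>`, not necessarily what ends up in the PR.

```python
import jax.numpy as jnp

# Toy batched matrix with shape (batch, n, n); the values are arbitrary.
K = jnp.arange(18.0).reshape(2, 3, 3)
n = K.shape[-1]

# Original expression: broadcast an identity matrix over the batch.
via_add = jnp.add(K, jnp.identity(n))

# Variant from the diff: subtract the negated identity.
via_subtract = jnp.subtract(K, -jnp.identity(n))

# Indexed update: the ellipsis keeps the batch dimensions, and the paired
# index arrays select the diagonal entries K[..., i, i].
idx = jnp.arange(n)
via_at = K.at[..., idx, idx].add(1.0)
```

All three produce the same values; they differ only in how the computation is expressed to the compiler.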
For some reason, I couldn't reproduce the issue locally. :(
It seems to only appear on some systems. For example, I can't reproduce the issue on my MacBook Pro with M1 chip, but it is reproducible in GitHub Actions (see here for an example run).
Here is a comparison of different implementations. Using the `at` indexing seems to do a reasonable job. I've just pushed an update.
You would need to add `.block_until_ready` for a fair comparison: see https://jax.readthedocs.io/en/latest/tutorials/quickstart.html#using-jit-to-speed-up-functions. I guess add/subtract is more performant than slice update.
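To illustrate the point about `.block_until_ready`, here is a minimal timing sketch. The names `toy_fn` and `time_call` are hypothetical and exist only for this example; the function being timed is a stand-in, not the code under review.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_fn(W):
    # Stand-in computation: a batched Gram matrix W^T W.
    return jnp.matmul(jnp.swapaxes(W, -1, -2), W)


def time_call(fn, *args, repeats=10):
    # Warm-up call so that compilation is not included in the measurement.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches work asynchronously; blocking on the result makes
        # the timer cover the computation rather than only the dispatch.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


W = jnp.ones((4, 5, 3))
seconds_per_call = time_call(toy_fn, W)
```

In a notebook, the same idea would be `%timeit toy_fn(W).block_until_ready()`.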
Yes, good point. Forgot about that. Using the `block_until_ready` call, I get
_original_batch_capacitance_tril: 677 ms ± 26.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
_sub_batch_capacitance_tril: 5.97 ms ± 713 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
_add_batch_capacitance_tril: 5.44 ms ± 900 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
on the machine that's raising the warning. On my local M1 chip, I get
_original_batch_capacitance_tril: 11.5 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
_sub_batch_capacitance_tril: 12 ms ± 716 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
_add_batch_capacitance_tril: 8.41 ms ± 1.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Thanks, Till!
15c4397 to 870d63a
numpyro/distributions/continuous.py
Outdated
@@ -71,7 +71,7 @@
     validate_sample,
     vec_to_tril_matrix,
 )
-from numpyro.util import is_prng_key
+from numpyro.util import add_diag, is_prng_key
Super cool!!
Could you move this helper to `numpyro.distributions.util` instead?
870d63a to 96dad61
* Address jax-ml/jax#19885 for numpyro.
* Implement function to add constant or batch of vectors to diagonal.
* Use `add_diag` helper function in `distributions` module.
* Move `add_diag` to `distributions.util` module.
This should address performance issues for `LowRankMultivariateNormal` distributions with batch dimensions. Only a single change was required to fix the issue. There are other `[matmul] + [identity]` expressions in the codebase, but, for some reason, they don't cause any issues. The tests verify that no warning is emitted.
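A helper that adds a constant or a batch of vectors to the diagonal, as described in the commit message, could look roughly like the sketch below. The name `add_diag_sketch` is hypothetical and the body is an assumption based on the `at` indexing discussed in the review; it is not copied from the merged code.

```python
import jax.numpy as jnp


def add_diag_sketch(matrix, diag):
    # Hypothetical helper: add `diag` to the diagonal of the trailing
    # (n, n) block of `matrix`. `diag` may be a scalar or an array that
    # broadcasts against the batch of diagonals, i.e. shape (..., n).
    idx = jnp.arange(matrix.shape[-1])
    return matrix.at[..., idx, idx].add(diag)


matrix = jnp.zeros((2, 3, 3))

# Scalar case: the same constant is added to every diagonal entry.
plus_const = add_diag_sketch(matrix, 2.0)

# Batched case: one vector of diagonal increments per matrix in the batch.
diag = jnp.arange(6.0).reshape(2, 3)
plus_batch = add_diag_sketch(matrix, diag)
```

With a helper of this shape, `K + eye` in `_batch_capacitance_tril` would become a call with `diag=1`, with no identity matrix materialized.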