Skip to content

Commit

Permalink
Address jax-ml/jax#19885 for numpyro.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Feb 27, 2024
1 parent 00f43d0 commit ff28690
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
2 changes: 1 addition & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -1860,7 +1860,7 @@ def _batch_capacitance_tril(W, D):
Wt_Dinv = jnp.swapaxes(W, -1, -2) / jnp.expand_dims(D, -2)
K = jnp.matmul(Wt_Dinv, W)
# could be inefficient
return jnp.linalg.cholesky(jnp.add(K, jnp.identity(K.shape[-1])))
return jnp.linalg.cholesky(jnp.subtract(K, -jnp.identity(K.shape[-1])))


def _batch_lowrank_logdet(W, D, capacitance_tril):
Expand Down
58 changes: 55 additions & 3 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from itertools import product
import math
import os
from typing import Callable

import numpy as np
from numpy.testing import assert_allclose, assert_array_equal
Expand Down Expand Up @@ -3070,9 +3071,11 @@ def sample(d: dist.Distribution):

for in_axes, out_axes in in_out_axes_cases:
batched_params = [
jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
if isinstance(ax, int)
else arg
(
jax.tree_map(lambda x: jnp.expand_dims(x, ax), arg)
if isinstance(ax, int)
else arg
)
for arg, ax in zip(params, in_axes)
]
# Recreate the jax_dist to avoid side effects coming from `d.sample`
Expand Down Expand Up @@ -3169,3 +3172,52 @@ def test_sample_truncated_normal_in_tail():
def test_jax_custom_prng():
samples = dist.Normal(0, 5).sample(random.PRNGKey(0), sample_shape=(1000,))
assert ~jnp.isinf(samples).any()


def _assert_not_jax_issue_19885(
capfd: pytest.CaptureFixture, func: Callable, *args, **kwargs
) -> None:
# jit-ing identity plus matrix multiplication leads to performance degradation as
# discussed in https://github.com/google/jax/issues/19885. This assertion verifies
# that the issue does not affect perforance in numpyro.
for jit in [True, False]:
result = jax.jit(func)(*args, **kwargs)
block_until_ready = getattr(result, "block_until_ready", None)
if block_until_ready:
result = block_until_ready()
_, err = capfd.readouterr()
assert (
"MatMul reference implementation being executed" not in err
), f"jit: {jit}"
return result


@pytest.mark.xfail
def test_jax_issue_19885(capfd: pytest.CaptureFixture) -> None:
def func_with_warning(y) -> jnp.ndarray:
return jnp.identity(y.shape[-1]) + jnp.matmul(y, y)

_assert_not_jax_issue_19885(capfd, func_with_warning, jnp.ones((20, 100, 100)))


def test_lowrank_mvn_19885(capfd: pytest.CaptureFixture) -> None:
# Create parameters.
batch_size = 100
event_size = 200
sample_size = 40
rank = 40
loc, cov_diag = random.normal(random.key(0), (2, batch_size, event_size))
cov_diag = jnp.exp(cov_diag)
cov_factor = random.normal(random.key(1), (batch_size, event_size, rank))

distribution = _assert_not_jax_issue_19885(
capfd, dist.LowRankMultivariateNormal, loc, cov_factor, cov_diag
)
x = _assert_not_jax_issue_19885(
capfd,
lambda x: distribution.sample(random.key(0), x.shape),
jnp.empty(sample_size),
)
assert x.shape == (sample_size, batch_size, event_size)
log_prob = _assert_not_jax_issue_19885(capfd, distribution.log_prob, x)
assert log_prob.shape == (sample_size, batch_size)

0 comments on commit ff28690

Please sign in to comment.