Skip to content

Commit

Permalink
avoid breakage in old jax version without jax.extend (#1647)
Browse files Browse the repository at this point in the history
* avoid breakage in old jax version without jax.extend

* fix lint
  • Loading branch information
fehiepsi authored Sep 22, 2023
1 parent 6e3f007 commit 59a188d
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 5 additions & 1 deletion numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from jax.api_util import flatten_fun, shaped_abstractify
import jax.core as core
from jax.experimental.pjit import pjit_p
import jax.extend.linear_util as lu

try:
import jax.extend.linear_util as lu
except ImportError:
import jax.linear_util as lu
from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.numpy as jnp
Expand Down
6 changes: 5 additions & 1 deletion test/ops/test_provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import jax
from jax.api_util import flatten_fun_nokwargs
import jax.core as core
import jax.extend.linear_util as lu

try:
import jax.extend.linear_util as lu
except ImportError:
import jax.linear_util as lu
import jax.numpy as jnp

from numpyro.ops.provenance import eval_provenance
Expand Down

0 comments on commit 59a188d

Please sign in to comment.