Skip to content

Commit

Permalink
Updated jax.config import
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 574896522
  • Loading branch information
superbobry authored and JAXopt authors committed Oct 23, 2023
1 parent b371982 commit 9703f61
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/implicit_diff/maml.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
"except (KeyError, RuntimeError):\n",
" print(\"TPU not found, continuing without it.\")\n",
"\n",
"from jax.config import config\n",
"from jax import config\n",
"config.update(\"jax_enable_x64\", True)\n",
"\n",
"import jax\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks/implicit_diff/maml.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ try:
except (KeyError, RuntimeError):
print("TPU not found, continuing without it.")
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)
import jax
Expand Down
2 changes: 1 addition & 1 deletion jaxopt/_src/scipy_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from typing import Union

import jax
from jax.config import config
from jax import config
import jax.numpy as jnp
import jax.tree_util as tree_util
from jax.tree_util import register_pytree_node_class
Expand Down
2 changes: 1 addition & 1 deletion tests/anderson_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import jax
import jax.numpy as jnp
from jax.config import config
from jax import config
from jax.tree_util import tree_map, tree_all
from jax.test_util import check_grads
import optax
Expand Down

0 comments on commit 9703f61

Please sign in to comment.