Skip to content

Commit

Permalink
Export jax.tree* to avoid breaking users.
Browse files Browse the repository at this point in the history
  • Loading branch information
hawkinsp committed May 7, 2020
1 parent e892ad0 commit 68daf38
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 2 deletions.
2 changes: 1 addition & 1 deletion docs/notebooks/How_JAX_primitives_work.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,7 @@
"\n",
"# Now we register the XLA compilation rule with JAX\n",
"# TODO: for GPU? and TPU?\n",
"from jax import xla\n",
"from jax.interpreters import xla\n",
"xla.backend_specific_translations['cpu'][multiply_add_p] = multiply_add_xla_translation"
],
"execution_count": 0,
Expand Down
11 changes: 10 additions & 1 deletion jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@
ShapedArray,
ShapeDtypeStruct,
soft_pmap,
# TODO(phawkins): hide tree* functions from jax, update callers to use
# jax.tree_util.
treedef_is_leaf,
tree_flatten,
tree_leaves,
tree_map,
tree_multimap,
tree_structure
tree_transpose,
tree_unflatten,
value_and_grad,
vjp,
vmap,
Expand All @@ -62,7 +72,6 @@
def _init():
import os
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
del os

import jax.numpy # side-effecting import sets up operator overloads

Expand Down

0 comments on commit 68daf38

Please sign in to comment.