Merge pull request #2325 from levskaya:treesymbolfix
PiperOrigin-RevId: 463659871
Flax Authors committed Jul 27, 2022
2 parents 8d3e987 + 4c2d87b commit 0740ef6
Showing 63 changed files with 223 additions and 223 deletions.
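
Every change in this commit follows one pattern: calls through the flat top-level aliases (`jax.tree_map`, `jax.tree_flatten`, `jax.tree_leaves`) are rewritten to their `jax.tree_util` equivalents. A minimal before/after sketch; the `params` pytree here is an illustrative stand-in, not taken from the diff:

```python
import jax
import jax.numpy as jnp

# Illustrative pytree; not taken from the diff.
params = {'dense': {'kernel': jnp.ones((3, 4)), 'bias': jnp.zeros((4,))}}

# Old spelling, removed throughout this commit:
#   jax.tree_map(jnp.shape, params)
# New spelling, added in its place:
shapes = jax.tree_util.tree_map(jnp.shape, params)
print(shapes)  # {'dense': {'bias': (4,), 'kernel': (3, 4)}}
```
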
6 changes: 3 additions & 3 deletions docs/advanced_topics/lift.md
@@ -86,7 +86,7 @@ class ManualVmapMLP(nn.Module):

xs = jnp.ones((3, 4))
variables = ManualVmapMLP().init(random.PRNGKey(0), xs)
-print(jax.tree_map(jnp.shape, variables['params']))
+print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
mlp: {
@@ -257,7 +257,7 @@ def lift_transpose(fn, target='params', variables=True, rngs=True):
if x.ndim == 2:
return x.T
return x
-target = jax.tree_map(trans, target)
+target = jax.tree_util.tree_map(trans, target)
variable_groups = (target, rest)
scope = scope_fn(variable_groups, rng_groups)
y = fn(scope, *args)
@@ -307,7 +307,7 @@ class LinenVmapMLP(nn.Module):
return VmapMLP(name='mlp')(xs)

variables = LinenVmapMLP().init(random.PRNGKey(0), xs)
-print(jax.tree_map(jnp.shape, variables['params']))
+print(jax.tree_util.tree_map(jnp.shape, variables['params']))
"""==>
{
mlp: {
6 changes: 3 additions & 3 deletions docs/advanced_topics/optax_update_guide.rst
@@ -167,10 +167,10 @@ becomes just another gradient transformation |optax.clip_by_global_norm()|_.

def train_step(optimizer, batch):
grads = jax.grad(loss)(optimizer.target, batch)
-grads_flat, _ = jax.tree_flatten(grads)
+grads_flat, _ = jax.tree_util.tree_flatten(grads)
global_l2 = jnp.sqrt(sum([jnp.vdot(p, p) for p in grads_flat]))
g_factor = jnp.minimum(1.0, grad_clip_norm / global_l2)
-grads = jax.tree_map(lambda g: g * g_factor, grads)
+grads = jax.tree_util.tree_map(lambda g: g * g_factor, grads)
return optimizer.apply_gradient(grads)

---
@@ -268,7 +268,7 @@ that is not readily available outside the outer mask).
kernels = flax.traverse_util.ModelParamTraversal(lambda p, _: 'kernel' in p)
biases = flax.traverse_util.ModelParamTraversal(lambda p, _: 'bias' in p)

-all_false = jax.tree_map(lambda _: False, params)
+all_false = jax.tree_util.tree_map(lambda _: False, params)
kernels_mask = kernels.update(lambda _: True, all_false)
biases_mask = biases.update(lambda _: True, all_false)

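The optax_update_guide hunks above note that manual clipping "becomes just another gradient transformation |optax.clip_by_global_norm()|_". For context, a sketch (not part of this commit) of that Optax form of global-norm clipping; `grad_clip_norm`, `learning_rate`, `loss`, and `params` are assumed placeholders:

```python
import jax
import optax

grad_clip_norm = 1.0   # assumed value, matching the guide's variable name
learning_rate = 0.1    # assumed hyperparameter for the sketch

tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_norm),  # clipping as a gradient transformation
    optax.sgd(learning_rate),
)

def train_step(params, opt_state, batch, loss):
  grads = jax.grad(loss)(params, batch)
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state

# Usage: opt_state = tx.init(params), then iterate train_step over batches.
```
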
4 changes: 2 additions & 2 deletions docs/flip/1009-optimizer-api.md
@@ -349,7 +349,7 @@ Remarks:
`OptimizerDef`).
- The functions `init_param_state()` and `apply_param_gradient()` are called
for every leaf in the params/grads pytree. This makes it possible to write the
-calculations directly without `jax.tree_map()`.
+calculations directly without `jax.tree_util.tree_map()`.
- The interface was defined in pre-Linen without the distinction of `params` vs.
other collections in `variables` in mind. The original API was elegant because
one only needed to pass around the optimizer, which included the parameters,
@@ -500,5 +500,5 @@ rng = jax.random.PRNGKey(0)
ds = tfds.load('mnist')['train'].take(160).map(pp).batch(16)
batch = next(iter(ds))
variables = model.init(rng, jnp.array(batch['image'][:1]))
-jax.tree_map(jnp.shape, variables)
+jax.tree_util.tree_map(jnp.shape, variables)
```
4 changes: 2 additions & 2 deletions docs/flip/1777-default-dtype.md
@@ -47,8 +47,8 @@ A simplified example implementation:
```python
def promote_arrays(*xs, dtype):
if dtype is None:
-dtype = jnp.result_type(*jax.tree_leaves(xs))
-return jax.tree_map(lambda x: jnp.asarray(x, dtype), xs)
+dtype = jnp.result_type(*jax.tree_util.tree_leaves(xs))
+return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype), xs)

Dtype = Any
class Dense(nn.Module):
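For context, a quick usage sketch (not part of the diff) of the `promote_arrays` helper from the hunk above, assuming it is defined as in the new version; the sample arrays are made up:

```python
import jax
import jax.numpy as jnp

def promote_arrays(*xs, dtype):
  # As in the updated FLIP example above.
  if dtype is None:
    dtype = jnp.result_type(*jax.tree_util.tree_leaves(xs))
  return jax.tree_util.tree_map(lambda x: jnp.asarray(x, dtype), xs)

a = jnp.ones((2,), jnp.float16)
b = jnp.ones((2,), jnp.float32)
a_p, b_p = promote_arrays(a, b, dtype=None)
print(a_p.dtype, b_p.dtype)  # float32 float32
```
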
2 changes: 1 addition & 1 deletion docs/getting_started.ipynb
@@ -420,7 +420,7 @@
"def eval_model(params, test_ds):\n",
" metrics = eval_step(params, test_ds)\n",
" metrics = jax.device_get(metrics)\n",
-" summary = jax.tree_map(lambda x: x.item(), metrics)\n",
+" summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)\n",
" return summary['loss'], summary['accuracy']"
]
},
2 changes: 1 addition & 1 deletion docs/getting_started.md
@@ -306,7 +306,7 @@ Create a model evaluation function that:
def eval_model(params, test_ds):
metrics = eval_step(params, test_ds)
metrics = jax.device_get(metrics)
-summary = jax.tree_map(lambda x: x.item(), metrics)
+summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)
return summary['loss'], summary['accuracy']
```

4 changes: 2 additions & 2 deletions docs/guides/extracting_intermediates.rst
@@ -198,12 +198,12 @@ In the following code example we check if any intermediate activations are non-f
def predict(variables, x):
y, state = CNN().apply(variables, x, capture_intermediates=True, mutable=["intermediates"])
intermediates = state['intermediates']
-fin = jax.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
+fin = jax.tree_util.tree_map(lambda xs: jnp.all(jnp.isfinite(xs)), intermediates)
return y, fin

variables = init(jax.random.PRNGKey(0), batch)
y, is_finite = predict(variables, batch)
-all_finite = all(jax.tree_leaves(is_finite))
+all_finite = all(jax.tree_util.tree_leaves(is_finite))
assert all_finite, "non finite intermediate detected!"

By default only the intermediates of ``__call__`` methods are collected.
8 changes: 4 additions & 4 deletions docs/guides/flax_basics.ipynb
@@ -156,7 +156,7 @@
"key1, key2 = random.split(random.PRNGKey(0))\n",
"x = random.normal(key1, (10,)) # Dummy input\n",
"params = model.init(key2, x) # Initialization call\n",
-"jax.tree_map(lambda x: x.shape, params) # Checking output shapes"
+"jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes"
]
},
{
@@ -375,7 +375,7 @@
"\n",
"@jax.jit\n",
"def update_params(params, learning_rate, grads):\n",
-" params = jax.tree_map(\n",
+" params = jax.tree_util.tree_map(\n",
" lambda p, g: p - learning_rate * g, params, grads)\n",
" return params\n",
"\n",
@@ -665,7 +665,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(params)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -773,7 +773,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(params)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
8 changes: 4 additions & 4 deletions docs/guides/flax_basics.md
@@ -91,7 +91,7 @@ outputId: 06feb9d2-db50-4f41-c169-6df4336f43a5
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (10,)) # Dummy input
params = model.init(key2, x) # Initialization call
-jax.tree_map(lambda x: x.shape, params) # Checking output shapes
+jax.tree_util.tree_map(lambda x: x.shape, params) # Checking output shapes
```

+++ {"id": "NH7Y9xMEewmO"}
@@ -207,7 +207,7 @@ loss_grad_fn = jax.value_and_grad(mse)
@jax.jit
def update_params(params, learning_rate, grads):
-params = jax.tree_map(
+params = jax.tree_util.tree_map(
lambda p, g: p - learning_rate * g, params, grads)
return params
@@ -352,7 +352,7 @@ model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
-print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
```

@@ -414,7 +414,7 @@ model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
-print('initialized parameter shapes:\n', jax.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
```

6 changes: 3 additions & 3 deletions docs/guides/jax_for_the_impatient.ipynb
@@ -1117,7 +1117,7 @@
"id": "nW1IKnjqXFdN"
},
"source": [
-"Now using our tree of params, we can write the gradient descent in a simpler way using `jax.tree_map`:"
+"Now using our tree of params, we can write the gradient descent in a simpler way using `jax.tree_util.tree_map`:"
]
},
{
@@ -1164,7 +1164,7 @@
"# Always remember to jit!\n",
"@jax.jit\n",
"def update_params_pytree(params, learning_rate, x_samples, y_samples):\n",
-" params = jax.tree_map(\n",
+" params = jax.tree_util.tree_map(\n",
" lambda p, g: p - learning_rate * g, params,\n",
" jax.grad(mse_pytree)(params, x_samples, y_samples))\n",
" return params\n",
@@ -1202,7 +1202,7 @@
"for i in range(101):\n",
" # Note that here the loss is computed before the param update.\n",
" loss_val, grads = loss_grad_fn(params, x_samples, y_samples)\n",
-" params = jax.tree_map(\n",
+" params = jax.tree_util.tree_map(\n",
" lambda p, g: p - learning_rate * g, params, grads)\n",
" if (i % 5 == 0):\n",
" print(f\"Loss step {i}: \", loss_val)"
6 changes: 3 additions & 3 deletions docs/guides/jax_for_the_impatient.md
@@ -576,7 +576,7 @@ jax.grad(mse_pytree)(params, x_samples, y_samples)

+++ {"id": "nW1IKnjqXFdN"}

-Now using our tree of params, we can write the gradient descent in a simpler way using `jax.tree_map`:
+Now using our tree of params, we can write the gradient descent in a simpler way using `jax.tree_util.tree_map`:

```{code-cell}
---
@@ -588,7 +588,7 @@ outputId: f309aff7-2aad-453f-ad88-019d967d4289
# Always remember to jit!
@jax.jit
def update_params_pytree(params, learning_rate, x_samples, y_samples):
-params = jax.tree_map(
+params = jax.tree_util.tree_map(
lambda p, g: p - learning_rate * g, params,
jax.grad(mse_pytree)(params, x_samples, y_samples))
return params
@@ -616,7 +616,7 @@ loss_grad_fn = jax.value_and_grad(mse_pytree)
for i in range(101):
# Note that here the loss is computed before the param update.
loss_val, grads = loss_grad_fn(params, x_samples, y_samples)
-params = jax.tree_map(
+params = jax.tree_util.tree_map(
lambda p, g: p - learning_rate * g, params, grads)
if (i % 5 == 0):
print(f"Loss step {i}: ", loss_val)
10 changes: 5 additions & 5 deletions docs/guides/model_surgery.rst
@@ -47,7 +47,7 @@ Let's create a small convolutional neural network model for our demo.
key = jax.random.PRNGKey(0)
params = get_initial_params(key)

-print(jax.tree_map(jnp.shape, params))
+print(jax.tree_util.tree_map(jnp.shape, params))

.. testoutput::

@@ -77,7 +77,7 @@ Next, get a flat dict for doing model surgery as follows:

# Get flattened-key: value list.
flat_params = traverse_util.flatten_dict(params)
-print(jax.tree_map(jnp.shape, flat_params))
+print(jax.tree_util.tree_map(jnp.shape, flat_params))

.. testoutput::
:options: +NORMALIZE_WHITESPACE
@@ -99,7 +99,7 @@ After doing whatever you want, unflatten back:
unflat_params = traverse_util.unflatten_dict(flat_params)
# Refreeze.
unflat_params = freeze(unflat_params)
-print(jax.tree_map(jnp.shape, unflat_params))
+print(jax.tree_util.tree_map(jnp.shape, unflat_params))

.. testoutput::
:options: +NORMALIZE_WHITESPACE
@@ -138,7 +138,7 @@ optimizer state that mirrors the original state.
opt_state = tx.init(params)

# The optimizer state is a tuple of gradient transformation states.
-print(jax.tree_map(jnp.shape, opt_state))
+print(jax.tree_util.tree_map(jnp.shape, opt_state))

.. testoutput::
:options: +NORMALIZE_WHITESPACE
@@ -163,7 +163,7 @@ parameters and can be flattened / modified exactly the same way
flat_mu = traverse_util.flatten_dict(opt_state[0].mu)
flat_nu = traverse_util.flatten_dict(opt_state[0].nu)

-print(jax.tree_map(jnp.shape, flat_mu))
+print(jax.tree_util.tree_map(jnp.shape, flat_mu))

.. testoutput::
:options: +NORMALIZE_WHITESPACE
14 changes: 7 additions & 7 deletions docs/notebooks/linen_intro.ipynb
@@ -304,7 +304,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -370,7 +370,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -762,7 +762,7 @@
"\n",
"print('updated variables:\\n', updated_variables)\n",
"print('initialized variable shapes:\\n', \n",
-" jax.tree_map(jnp.shape, init_variables))\n",
+" jax.tree_util.tree_map(jnp.shape, init_variables))\n",
"print('output:\\n', y)\n",
"\n",
"# Let's run these model variables during \"evaluation\":\n",
@@ -847,7 +847,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -919,7 +919,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -1076,7 +1076,7 @@
" batch_axes=(0,))\n",
"\n",
"init_variables = model(train=False).init({'params': key2}, x, x)\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"\n",
"y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n",
"print('output:\\n', y.shape)"
@@ -1153,7 +1153,7 @@
"model = SimpleScan()\n",
"init_variables = model.init(key2, xs)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"\n",
"y = model.apply(init_variables, xs)\n",
"print('output:\\n', y)"
(Diff truncated: the remaining changed files are not shown here.)