Pytrees play a central role in JAX:
Pytrees are a good fit for implementing flexible tensors with tree-shaped indices (V-values), as described in Sections 3 and 5.3 of "Dataflow Matrix Machines and V-values: a Bridge between Programs and Neural Nets",
JAX is capable of taking gradients with respect to variables accumulated within pytrees (see e.g. the last section ("Linear regression with Pytrees") of
Here is a simple test performed on February 22, 2022 (Ubuntu 20.04.3 LTS, Python 3.8.10, JAX 0.3.0, tree-math 0.1.0):
It is convenient to use the tree-math library.
We take the gradient of the sum of tree_map(relu, x) with respect to a nested dictionary x.
Note that despite JAX's reputation for being "static", we can change the shape of the dictionary x on the fly, and grad_f keeps working correctly.
>>> from jax.nn import relu
>>> from jax.tree_util import tree_map
>>> from jax import numpy as jnp
>>> import tree_math as tm
>>> from jax import grad
>>> d = {}
>>> d["x"] = jnp.array([3., -4])
>>> d["y"] = jnp.array([8., 0])
>>> d
{'x': DeviceArray([ 3., -4.], dtype=float32), 'y': DeviceArray([8., 0.], dtype=float32)}
>>> tree_map(relu, d)
{'x': DeviceArray([3., 0.], dtype=float32), 'y': DeviceArray([8., 0.], dtype=float32)}
>>> def f(x):
...     return tm.Vector(tree_map(relu, x)).sum()
...
>>> f(d)
DeviceArray(11., dtype=float32)
>>> grad_f = grad(f)
>>> grad_f(d)
{'x': DeviceArray([1., 0.], dtype=float32), 'y': DeviceArray([1., 0.], dtype=float32)}
>>> d["deeper"] = {"inner": jnp.array([-7, 13, 0.])}
>>> d
{'x': DeviceArray([ 3., -4.], dtype=float32), 'y': DeviceArray([8., 0.], dtype=float32), 'deeper': {'inner': DeviceArray([-7., 13., 0.], dtype=float32)}}
>>> tree_map(relu, d)
{'deeper': {'inner': DeviceArray([ 0., 13., 0.], dtype=float32)}, 'x': DeviceArray([3., 0.], dtype=float32), 'y': DeviceArray([8., 0.], dtype=float32)}
>>> f(d)
DeviceArray(24., dtype=float32)
>>> grad_f(d)
{'deeper': {'inner': DeviceArray([0., 1., 0.], dtype=float32)}, 'x': DeviceArray([1., 0.], dtype=float32), 'y': DeviceArray([1., 0.], dtype=float32)}
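
The same experiment can also be sketched as a standalone script that does not depend on tree-math. This is a minimal sketch and not part of the original test: it assumes a reasonably recent JAX release, it sums the leaves with jax.tree_util.tree_leaves instead of tm.Vector(...).sum(), and the names flat, nested, and jit_grad_f are chosen here for illustration. The jax.jit wrapper is an addition that shows the gradient also works when compiled, since JAX retraces the function when the structure of the input dictionary changes.

```python
import jax
import jax.numpy as jnp
from jax.nn import relu
from jax.tree_util import tree_leaves, tree_map


def f(tree):
    # Apply relu to every leaf, then add up all leaf entries.
    activated = tree_map(relu, tree)
    return sum(jnp.sum(leaf) for leaf in tree_leaves(activated))


grad_f = jax.grad(f)          # gradient has the same structure as the input
jit_grad_f = jax.jit(grad_f)  # compiled variant (illustrative addition)

# Two inputs with different dictionary structures.
flat = {"x": jnp.array([3.0, -4.0]), "y": jnp.array([8.0, 0.0])}
nested = dict(flat, deeper={"inner": jnp.array([-7.0, 13.0, 0.0])})

for tree in (flat, nested):
    print(f(tree))           # scalar sum of relu over all leaves
    print(grad_f(tree))      # nested dictionary of gradients
    print(jit_grad_f(tree))  # same result; jit retraces for a new structure
```

The printed array type depends on the JAX version (DeviceArray in JAX 0.3.0, as in the session above), so no expected output is shown here.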