omnistaging on by default
mattjj committed Aug 14, 2020
1 parent bd14f23 commit 1bce9c3
Showing 19 changed files with 876 additions and 943 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/ci-build.yaml
@@ -53,26 +53,26 @@ jobs:
- python-version: 3.6
os: ubuntu-latest
enable-x64: 0
enable-omnistaging: 0
enable-omnistaging: 1
package-overrides: "none"
num_generated_cases: 25
- python-version: 3.7
os: ubuntu-latest
enable-x64: 1
enable-omnistaging: 0
enable-omnistaging: 1
package-overrides: "none"
num_generated_cases: 25
- python-version: 3.6
os: ubuntu-latest
enable-x64: 1
enable-omnistaging: 0
enable-omnistaging: 1
# Test with numpy version that matches Google-internal version
package-overrides: "numpy==1.16.4"
num_generated_cases: 10
- python-version: 3.7
os: ubuntu-latest
enable-x64: 0
enable-omnistaging: 1
enable-omnistaging: 0
package-overrides: "none"
num_generated_cases: 8
steps:
209 changes: 63 additions & 146 deletions docs/jaxpr.rst
@@ -162,37 +162,6 @@ before (with two input vars, one for each element of the input tuple)
in (f,) }



Constant Vars
--------------

ConstVars arise when the computation contains array constants, either
from the Python program, or from constant-folding. For example, the function
``func6`` below

>>> def func5(first, second):
... temp = first + jnp.sin(second) * 3. - jnp.ones(8)
... return temp
...
>>> def func6(first):
... return func5(first, jnp.ones(8))
...

JAX produces the following jaxpr

>>> print(make_jaxpr(func6)(jnp.ones(8)))
{ lambda b d ; a.
let c = add a b
e = sub c d
in (e,) }

When tracing ``func6``, the function ``func5`` is invoked with a constant value
(``jnp.ones(8)``) for the second argument. As a result, the sub-expression
``jnp.sin(second) * 3.`` is constant-folded.
There are two ConstVars, ``b`` (standing for ``jnp.sin(second) * 3.``) and ``d``
(standing for ``jnp.ones(8)``). Unfortunately, it is not easy to tell from the
jaxpr notation what constants the constant variables stand for.

Higher-order primitives
-----------------------

@@ -293,44 +262,25 @@ contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
>>> def func8(arg1, arg2): # arg2 is a pair
... return lax.cond(arg1 >= 0.,
... lambda xtrue: xtrue[0],
... lambda xfalse: jnp.ones(1) + xfalse[1],
... lambda xfalse: jnp.array([1]) + xfalse[1],
... arg2)
...
>>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
{ lambda f ; a b c.
let d = ge a 0.0
e = convert_element_type[ new_dtype=int32
old_dtype=bool ] d
g = cond[ branches=( { lambda ; c a b.
let d = add c b
in (d,) }
{ lambda ; e_ a b.
let
{ lambda a ; b c d.
let e = ge b 0.0
f = convert_element_type[ new_dtype=int32
old_dtype=bool ] e
g = cond[ branches=( { lambda ; a b c.
let d = convert_element_type[ new_dtype=float32
old_dtype=int32 ] a
e = add d c
in (e,) }
{ lambda ; f_ a b.
let
in (a,) } )
linear=(False, False, False) ] e f b c
linear=(False, False, False) ] f a c d
in (g,) }

The top-level jaxpr has one `constvar` ``f`` (corresponding to
``jnp.ones(1)`` from the body of the first (false) branch) and three
input variables ``a b c`` (corresponding to ``arg1`` and the two
elements of ``arg2``; note that ``arg2`` has been flattened). The
``false_jaxpr`` has three input variables (``c`` corresponding to the
constant for ``jnp.ones(1)``, and ``a b`` for the two elements of
``arg2`` that are passed to ``false_jaxpr``). The ``true_jaxpr`` has
three input variables. The first (``e_``) is an unused argument
matching the constant first argument ``c`` of ``false_jaxpr``
(required for the jaxpr signatures to match). The subsequent two
correspond to the two elements of ``arg2`` that are passed to
``true_jaxpr``.

The actual operands to the cond primitive are: ``e f b c``, which
correspond in order to:

* one operand for the predicate,
* one constant (only used by ``false_jaxpr``, but passed to both),
i.e., ``f``, which is a constvar for the top-level jaxpr
* two operands passed to both jaxprs, i.e., ``b`` and ``c``, which are
input vars, corresponding to ``arg2`` for the top-level jaxpr.
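Setting aside the details of how constants and operands are packed, the run-time semantics of ``lax.cond`` can be sketched in pure Python. This is an illustrative reference only (``cond_reference`` is invented for exposition; the real primitive stages both branches out to jaxprs and selects between them at run time):

```python
def cond_reference(pred, true_fun, false_fun, operand):
    """Reference semantics of lax.cond: apply exactly one branch to the operand."""
    return true_fun(operand) if pred else false_fun(operand)

# Mirrors func8 above: the operand is the pair arg2 = ([0.0], 2.0).
result = cond_reference(5.0 >= 0.0,
                        lambda xtrue: xtrue[0],
                        lambda xfalse: 1.0 + xfalse[1],
                        ([0.0], 2.0))
# result == [0.0]  (the predicate is true, so the first branch runs)
```

With a negative first argument the false branch would run instead, producing ``1.0 + 2.0``.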

While
^^^^^
@@ -357,32 +307,22 @@ For example, here is an example fori loop
... arg + ones)
...
>>> print(make_jaxpr(func10)(np.ones(16), 5))
{ lambda c d ; a b.
let e = add a d
_ _ f = while[ body_jaxpr={ lambda ; e g a b c.
let d = add a 1
f = add c e
h = add f g
in (d, b, h) }
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d = add a c
_ _ e = while[ body_jaxpr={ lambda ; a b c d e.
let f = add c 1
g = mul a 3.0
h = add e g
i = add h b
in (f, d, i) }
body_nconsts=2
cond_jaxpr={ lambda ; a b c.
let d = lt a b
in (d,) }
cond_nconsts=0 ] c a 0 b e
in (f,) }

The top-level jaxpr has two constvars: ``c`` (corresponding to ``ones * 3.`` from the body
of the loop) and ``d`` (corresponding to the use of ``ones`` in the initial carry).
There are also two input variables (``a`` corresponding to ``arg`` and ``b`` corresponding
to ``n``).
The loop carry consists of three values, as seen in the body of ``cond_jaxpr``
(corresponding to the iteration index, iteration end, and the accumulated value carry).
Note that ``body_jaxpr`` takes 5 input variables. The first two are actually
constvars: ``e`` corresponding to ``ones * 3`` and ``g`` corresponding to the
captured use of ``arg`` in the loop body.
The parameter ``body_nconsts = 2`` specifies that there are 2 constants for the
``body_jaxpr``.
The other 3 input variables for ``body_jaxpr`` correspond to the flattened carry values.
cond_nconsts=0 ] c a 0 b d
in (e,) }

The while primitive takes 5 arguments: ``c a 0 b e``, as follows:

@@ -395,13 +335,13 @@ Scan

JAX supports a special form of loop over the elements of an array (with
statically known shape). The fact that there are a fixed number of iterations
makes this form of looping easily reverse-differentiable. Such loops are constructed
with the :py:func:`jax.lax.scan` operator::
makes this form of looping easily reverse-differentiable. Such loops are
constructed with the :py:func:`jax.lax.scan` function::

lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

Here ``C`` is the type of the scan carry, ``A`` is the element type of the input array(s),
and ``B`` is the element type of the output array(s).
Here ``C`` is the type of the scan carry, ``A`` is the element type of the
input array(s), and ``B`` is the element type of the output array(s).
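This signature can be made concrete with a pure-Python reference implementation (illustrative only; ``scan_reference`` is invented for exposition, and the real ``lax.scan`` stacks the per-step outputs into an array rather than a list):

```python
def scan_reference(body_fun, init_carry, xs):
    """Reference semantics of lax.scan.

    body_fun: (C, A) -> (C, B); returns the final carry and the list of B outputs.
    """
    carry = init_carry
    ys = []
    for x in xs:
        carry, y = body_fun(carry, x)
        ys.append(y)
    return carry, ys

# Running sum that also emits each partial sum.
final, partials = scan_reference(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
# final == 10, partials == [1, 3, 6, 10]
```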

For example, consider the function ``func11`` below

@@ -415,12 +355,14 @@ For the example consider the function ``func11`` below
... return lax.scan(body, 0., (arr, ones))
...
>>> print(make_jaxpr(func11)(np.ones(16), 5.))
{ lambda c ; a b.
let d e = scan[ jaxpr={ lambda ; f a b c.
let d = mul b c
e = add a d
g = add e f
in (g, a) }
{ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(16,) ] 1.0
d e = scan[ jaxpr={ lambda ; a b c d.
let e = mul c d
f = add b e
g = add f a
in (g, b) }
length=16
linear=(False, False, False, False)
num_carry=1
@@ -429,17 +371,6 @@ For the example consider the function ``func11`` below
unroll=1 ] b 0.0 a c
in (d, e) }

The top-level jaxpr has one constvar ``c`` corresponding to the ``ones`` constant,
and two input variables corresponding to the arguments ``arr`` and ``extra``.
The body of the scan has 4 input variables, of which:

* one (``f``) is a constant (since ``num_consts = 1``), and stands for the
captured variable ``extra`` used in the loop body,
* one (``a``) is the value of the carry (since ``num_carry = 1``)
* The remaining 2 are the input values. ``b`` is the array element from the
first array passed to lax.scan (``arr``) and ``c`` is the second array
(``ones``).

The ``linear`` parameter describes, for each input variable, whether it is
guaranteed to be used linearly in the body. Once the scan goes through
linearization, more arguments will be linear.
@@ -466,37 +397,27 @@ computation should run. For example
... return arg + inner(arg - 2.)
...
>>> print(make_jaxpr(func12)(1.))
{ lambda b ; a.
let c = sub a 2.0
d = xla_call[ backend=None
call_jaxpr={ lambda ; c b a.
let d = mul b c
e = add a d
{ lambda ; a.
let b = sub a 2.0
c = xla_call[ backend=None
call_jaxpr={ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
d = mul a c
e = add b d
in (e,) }
device=None
donated_invars=(False, False, False)
name=inner ] b a c
e = add a d
in (e,) }

The top-level constvar ``b`` refers to the ``jnp.ones(1)`` constant, and
the top-level input variable `a` refers to the ``arg`` parameter of ``func12``.
The ``xla_call`` primitive stands for a call to the jitted ``inner`` function.
The primitive has the function body in the ``call_jaxpr`` parameter, a jaxpr
with 3 input parameters:

* ``c`` is a constvar and stands for the ``ones`` constant,
* ``b`` corresponds to the free variable ``arg`` captured in the ``inner`` function,
* ``a`` corresponds to the ``inner`` parameter ``x``.
donated_invars=(False, False)
name=inner ] a b
d = add a c
in (d,) }

The primitive takes three arguments ``b a c``.

XLA_pmap
^^^^^^^^

If you use the :py:func:`jax.pmap` transformation, the function to be
mapped is captured using the ``xla_pmap`` primitive. Consider this
example
If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
captured using the ``xla_pmap`` primitive. Consider this example

>>> from jax import pmap
>>>
@@ -507,34 +428,30 @@ example
... return pmap(inner, axis_name='rows')(arr)
...
>>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
{ lambda c ; a b.
let d = xla_pmap[ axis_name=rows
{ lambda ; a b.
let c = xla_pmap[ axis_name=rows
axis_size=1
backend=None
call_jaxpr={ lambda ; d b a.
let c = add a b
call_jaxpr={ lambda ; a b.
let c = add b a
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
e = add c d
f = psum[ axis_index_groups=None
axis_name=rows ] a
axis_name=rows ] b
g = div e f
in (g,) }
devices=None
donated_invars=(False, False, False)
donated_invars=(False, False)
global_axis_size=None
mapped_invars=(True, False, True)
name=inner ] c b a
in (d,) }
mapped_invars=(False, True)
name=inner ] b a
in (c,) }

The top-level constvar ``c`` refers to the ``jnp.ones(1)`` constant.
The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)
and the body of the function to be mapped as the ``call_jaxpr`` parameter. The
and the body of the function to be mapped as the ``call_jaxpr`` parameter.
value of this parameter is a Jaxpr with 3 input variables:

* ``d`` stands for the constant ``jnp.ones(1)``,
* ``b`` stands for the free variable ``extra``,
* ``a`` stands for the parameter ``x`` of ``inner``.


The parameter ``mapped_invars`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast, the other input values are mapped.
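The mapped-versus-broadcast distinction can be sketched with a sequential pure-Python loop (illustrative only; ``pmap_reference`` is an invention for exposition, and real ``pmap`` runs each slice in parallel on a separate device):

```python
def pmap_reference(fun, args, mapped, axis_size):
    """Reference semantics of the mapped_invars convention.

    Mapped arguments are indexed along the leading axis, one slice per
    device; unmapped (broadcast) arguments are passed whole to every call.
    """
    outs = []
    for i in range(axis_size):
        sliced = [a[i] if m else a for a, m in zip(args, mapped)]
        outs.append(fun(*sliced))
    return outs

# Two "devices": x is mapped over its leading axis, extra is broadcast.
res = pmap_reference(lambda x, extra: x + extra,
                     [[1.0, 2.0], 10.0],
                     [True, False],
                     axis_size=2)
# res == [11.0, 12.0]
```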
19 changes: 10 additions & 9 deletions jax/api.py
@@ -211,7 +211,7 @@ def disable_jit():
... return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
[5 7 9]
Here ``y`` has been abstracted by :py:func:`jit` to a :py:class:`ShapedArray`,
@@ -379,7 +379,7 @@ def computation_maker(*args, **kwargs):
else:
pvals = [pe.PartialVal.unknown(aval) for aval in avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, pvals, instantiate=True, stage_out=True)
jaxtree_fun, pvals, instantiate=True, stage_out=True) # type: ignore
out_avals = [raise_to_shaped(pval.get_aval()) for pval in out_pvals]
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
@@ -1589,11 +1589,12 @@ def make_jaxpr(fun: Callable,
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; a.
let b = cos a
c = cos b
d = mul 1.0 c
e = neg d
f = sin a
g = mul e f
c = sin a
_ = sin b
d = cos b
e = mul 1.0 d
f = neg e
g = mul f c
in (g,) }
"""
_check_callable(fun)
@@ -1614,7 +1615,7 @@ def jaxpr_maker(*args, **kwargs):
else:
in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
jaxtree_fun, in_pvals, instantiate=True, stage_out=True)
jaxtree_fun, in_pvals, instantiate=True, stage_out=True) # type: ignore
out_avals = map(raise_to_shaped, unzip2(out_pvals)[0])
typed_jaxpr = core.TypedJaxpr(jaxpr, consts, in_avals, out_avals)
return typed_jaxpr
@@ -1886,7 +1887,7 @@ def __call__(self, *args):
if config.omnistaging_enabled:
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
else:
with core.initial_style_staging():
with core.initial_style_staging(): # type: ignore
jaxpr, _, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, instantiate=True)
outs = self.prim.bind(*it.chain(consts, args_flat), jaxpr=jaxpr,
in_tree=in_tree, out_tree=out_tree(),
