Skip to content

Commit

Permalink
Merge pull request #2734 from google/tycheck
Browse files Browse the repository at this point in the history
typecheck jaxprs
  • Loading branch information
froystig authored May 22, 2020
2 parents 77e4d8b + c293a10 commit 96c20f3
Show file tree
Hide file tree
Showing 6 changed files with 206 additions and 63 deletions.
187 changes: 159 additions & 28 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,10 @@ def __repr__(self): return '*'
unit = Unit()
literalable_types.add(Unit)

class UnitVar(object):
class UnitVar(Var):
count = -1
suffix = None
def __init__(self): pass
@property
def aval(self): return abstract_unit
def __repr__(self): return '*'
Expand Down Expand Up @@ -1048,48 +1051,176 @@ def call_impl(f: lu.WrappedFun, *args, **params):
call_p.def_impl(call_impl)


# ------------------- Jaxpr printed representation -------------------
# ------------------- Jaxpr checking -------------------

def mapped_aval(size, aval):
if aval is abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
# might be raising abstraction level from Concrete here
assert aval.shape[0] == size
return ShapedArray(aval.shape[1:], aval.dtype)
else:
raise TypeError(f"Mapped operand {aval}")

def unmapped_aval(size, aval):
if aval is abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
return ShapedArray((size,) + aval.shape, aval.dtype)
else:
raise TypeError(f"Mapped output {aval}")

def typecheck(aval, x):
return typecompat(aval, get_aval(x))

def typecompat(aval_ref, aval):
"""Determine whether `aval` conforms to `aval_ref`"""
aval_ref = raise_to_shaped(aval_ref).strip_weak_type()
try:
return aval_ref == lattice_join(aval_ref, aval).strip_weak_type()
except TypeError:
return False

def typematch(aval1, aval2):
return (raise_to_shaped(aval1).strip_weak_type() ==
raise_to_shaped(aval2).strip_weak_type())

# For use in Jaxpr typechecking (under `check_jaxpr`)
class _JaxprTypeEnvironment(object):
__slots__ = ["env"]

def __init__(self):
self.env: Dict[Var, AbstractValue] = {}

def read(self, v: Var):
env = self.env
if type(v) is not Literal:
if v not in env:
raise TypeError(
"Variable '{}' not defined".format(v))
if v.aval != env[v]:
raise TypeError(
"Variable '{}' inconsistently typed as {}, bound as {}".format(
v, v.aval, env[v]))
return v

def write(self, v: Var):
env = self.env
if v in env:
raise TypeError(
"Variable {} already bound".format(v))
env[v] = v.aval
return v

def check_jaxpr(jaxpr: Jaxpr):
"""Checks well-formedness of a jaxpr.
Specifically it checks that all variabled used are previously defined.
"""
def context():
return "\njaxpr:\n{}\n".format(jaxpr)
Specifically, check that:
- variables that are read are bound beforehand
- variables are typed equally throughout a jaxpr
- variable type annotations are compatible with their binding expression
def read_env(env: Set[Var], v: Var):
if type(v) is not Literal and v not in env:
raise Exception("Variable '{}' not defined".format(v) + context())
Raises `TypeError` if `jaxpr` is determined invalid. Returns `None` otherwise.
"""
try:
_check_jaxpr(jaxpr)
except Exception as e:
exception_type = type(e)
msg_context = f"while checking jaxpr:\n\n{jaxpr}\n"
if len(e.args) == 0:
exception_args = [msg_context]
else:
msg = f"{e.args[0]}\n\n" + msg_context
exception_args = [msg, *e.args[1:]]
raise exception_type(*exception_args) from e

def write_env(env: Set[Var], v: Var):
if v in env:
raise Exception("Variable {} already bound".format(v) + context())
env.add(v)
def _check_jaxpr(jaxpr: Jaxpr):
env = _JaxprTypeEnvironment()

env: Set[Var] = set()
read = partial(read_env, env)
write = partial(write_env, env)
env.write(unitvar)
map(env.write, jaxpr.constvars)
map(env.write, jaxpr.invars)

write(unitvar)
map(write, jaxpr.constvars)
map(write, jaxpr.invars)
for eqn in jaxpr.eqns:
if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
if "call_jaxpr" not in eqn.params:
raise Exception("Call primitive {} should have a 'call_jaxpr' parameter"
.format(eqn.primitive))
map(read, eqn.invars)
map(write, eqn.outvars)
check_jaxpr_eqn(env, eqn)

for subjaxpr in subjaxprs(jaxpr):
check_jaxpr(subjaxpr)
_check_jaxpr(subjaxpr)

map(env.read, jaxpr.outvars)

def _valid_eqn_assignment(dst_aval, src_aval):
# TODO(frostig): we'd rather this check simply be `typecompat` and not allow
# assignment to an AbstractUnit, but partial_eval.tracers_to_jaxpr types eqn
# outvars as AbstractUnit if the outvars are unused.
return dst_aval is abstract_unit or typecompat(dst_aval, src_aval)

def check_jaxpr_eqn(env, eqn):
invars = map(env.read, eqn.invars)
inferred_out_avals = type_transfer(eqn.primitive, invars, eqn.params)
outvars = map(env.write, eqn.outvars)

for outvar, inferred_out_aval in zip(outvars, inferred_out_avals):
if not _valid_eqn_assignment(outvar.aval, inferred_out_aval):
raise TypeError(
f"Jaxpr equation LHS {outvar} is {outvar.aval}, "
f"RHS is inferred as {inferred_out_aval}, in '{eqn}'")

def type_transfer(prim, invars, params):
in_avals = [v.aval for v in invars]

if prim.call_primitive or prim.map_primitive:
if "call_jaxpr" not in params:
raise TypeError(
f"Call primitive {prim} missing 'call_jaxpr' parameter")

if prim.map_primitive:
if "axis_size" not in params:
raise TypeError(
f"Map primitive {prim} missing 'axis_size' parameter")
if "mapped_invars" not in params:
raise TypeError(
f"Map primitive {prim} missing 'mapped_invars' parameter")

call_jaxpr = params["call_jaxpr"]
if len(invars) != len(call_jaxpr.invars):
raise TypeError(
f"Call primitive {prim} with {len(invars)} operands "
f"cannot call jaxpr with {len(call_jaxpr.invars)} invars")

binder_avals = [v.aval for v in call_jaxpr.invars]

if prim.map_primitive:
axis_size = params["axis_size"]
mapped_invars = params["mapped_invars"]
binder_avals = [unmapped_aval(axis_size, aval) if mapped else aval
for aval, mapped in zip(binder_avals, mapped_invars)]

for binder_aval, in_aval in zip(binder_avals, in_avals):
if not typecompat(binder_aval, in_aval):
raise TypeError(
f"Call primitive {prim} passes operand {in_aval} "
f"to jaxpr expecting {binder_aval}")

out_avals = [v.aval for v in call_jaxpr.outvars]

if prim.map_primitive:
axis_size = params["axis_size"]
out_avals = [unmapped_aval(axis_size, aval) for aval in out_avals]
else:
out_avals = prim.abstract_eval(*in_avals, **params)

map(read, jaxpr.outvars)
if not prim.multiple_results:
out_avals = [out_avals]

return out_avals


# ------------------- Jaxpr printed representation -------------------

def pp_vars(vs) -> str:
return ' '.join(map(str, vs))
return ' '.join(map(str, vs))

def pp_eqn_compact(primitive_name: str, params: Dict) -> PrettyPrint:
filtered_params = {k: v for k, v in params.items()
Expand Down
25 changes: 4 additions & 21 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,13 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
params = dict(params, name=name)
in_pvs, in_consts = unzip2([t.pval for t in tracers])
reduced_pvs = [None if pv is None else
_mapped_aval(params['axis_size'], pv) if m else pv
core.mapped_aval(params['axis_size'], pv) if m else pv
for pv, m in zip(in_pvs, params['mapped_invars'])]
fun, aux = partial_eval(f, self, reduced_pvs)
out_flat = map_primitive.bind(fun, *in_consts, **params)
out_pvs_reduced, jaxpr, env = aux()
out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])
out_pvs = [None if pv is None else _unmapped_aval(params['axis_size'], pv)
out_pvs = [None if pv is None else core.unmapped_aval(params['axis_size'], pv)
for pv in out_pvs_reduced]
const_tracers = map(self.new_instantiated_const, consts)
env_tracers = map(self.full_raise, env)
Expand All @@ -261,7 +261,8 @@ def process_map(self, map_primitive, f: lu.WrappedFun, tracers, params):
def post_process_map(self, map_primitive, out_tracers, params):
jaxpr, consts, env = tracers_to_jaxpr([], out_tracers)
out_pvs_reduced, out_pv_consts = unzip2(t.pval for t in out_tracers)
out_pvs = [None if pv is None else _unmapped_aval(params['axis_size'], pv)
out_pvs = [None if pv is None
else core.unmapped_aval(params['axis_size'], pv)
for pv in out_pvs_reduced]
out = out_pv_consts + consts
del consts, out_pv_consts
Expand Down Expand Up @@ -305,24 +306,6 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
class StagingJaxprTrace(JaxprTrace):
pass

def _mapped_aval(size, aval):
if aval is core.abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
# might be raising abstraction level from Concrete here
assert aval.shape[0] == size
return ShapedArray(aval.shape[1:], aval.dtype)
else:
raise TypeError(aval)

def _unmapped_aval(size, aval):
if aval is core.abstract_unit:
return aval
elif isinstance(aval, ShapedArray):
return ShapedArray((size,) + aval.shape, aval.dtype)
else:
raise TypeError(aval)

custom_partial_eval_rules: Dict[core.Primitive, Callable] = {}
call_partial_eval_rules: Dict[core.Primitive, Callable] = {}
staged_out_calls: Set[core.Primitive] = set()
Expand Down
2 changes: 2 additions & 0 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,8 @@ def _axis_index_translation_rule(c, nreps, sizes, soft_size, axis_name):

axis_index_p = core.Primitive('axis_index')
axis_index_p.def_custom_bind(_axis_index_bind)
axis_index_p.def_abstract_eval(
lambda *args, **params: ShapedArray((), onp.int32))
xla.translations[axis_index_p] = _axis_index_translation_rule


Expand Down
13 changes: 1 addition & 12 deletions jax/lax/lax_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from jax import linear_util as lu
from jax.abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from jax.api_util import flatten_fun_nokwargs, apply_flat_fun_nokwargs
from jax.core import get_aval
from jax.core import get_aval, typecheck, typematch
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
Expand Down Expand Up @@ -119,17 +119,6 @@ def type_and_const_convert_jaxpr(jaxpr, out_pvals):
def _abstractify(x):
return raise_to_shaped(core.get_aval(x))

def typecheck(aval, x):
aval = raise_to_shaped(aval).strip_weak_type()
try:
return aval == core.lattice_join(aval, core.get_aval(x)).strip_weak_type()
except TypeError:
return False

def typematch(aval1, aval2):
return (raise_to_shaped(aval1).strip_weak_type() ==
raise_to_shaped(aval2).strip_weak_type())

def _disable_jit_impl(prim, interp, *args, **kwargs):
if jax.api._jit_is_disabled():
return interp(*args, **kwargs)
Expand Down
3 changes: 2 additions & 1 deletion jax/lax/lax_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,8 @@ def _translate(val):
return xops.Tuple(c, list(map(_translate, args)))

psum_p = standard_pmap_primitive('psum', multiple_results=True)
psum_p.def_abstract_eval(lambda *args, **params: map(raise_to_shaped, args))
psum_p.def_abstract_eval(
lambda *args, **params: tuple(map(raise_to_shaped, args)))
pxla.split_axis_rules[psum_p] = \
partial(_allreduce_split_axis_rule, psum_p, lax._reduce_sum)
xla.parallel_translations[psum_p] = _psum_translation_rule
Expand Down
39 changes: 38 additions & 1 deletion tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from jax import core
from jax import numpy as jnp
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.abstract_arrays import make_shaped_array
from jax.api import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce, tree_leaves
from jax.util import partial
Expand Down Expand Up @@ -304,5 +305,41 @@ def test_var_tree_flatten(self):
syms = {c: d, a: b}
assert 'bd' == ''.join(map(str, tree_leaves(syms)))

def test_check_jaxpr_correct(self):
jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
core.check_jaxpr(jaxpr)

def test_check_jaxpr_eqn_mismatch(self):
def f(x):
return jnp.sin(x) + jnp.cos(x)

def new_jaxpr():
return make_jaxpr(f)(1.).jaxpr

# jaxpr is:
#
# { lambda ; a.
# let b = sin a
# c = cos a
# d = add b c
# in (d,) }
#
# NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

jaxpr = new_jaxpr()
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(2) # int, not float!
jtu.check_raises_regexp(
lambda: core.check_jaxpr(jaxpr),
TypeError, ("Jaxpr equation LHS .* is ShapedArray(.*), "
"RHS is inferred as ShapedArray(.*), in '.* = sin .*'"))

jaxpr = new_jaxpr()
jaxpr.eqns[0].outvars[0].aval = make_shaped_array(np.ones((2, 3)))
jtu.check_raises_regexp(
lambda: core.check_jaxpr(jaxpr),
TypeError, ("Jaxpr equation LHS .* is ShapedArray(.*), "
"RHS is inferred as ShapedArray(.*), in '.* = sin .*'"))


if __name__ == '__main__':
absltest.main()

0 comments on commit 96c20f3

Please sign in to comment.