Skip to content

Commit

Permalink
improve custom_jvp/vjp error messages
Browse files Browse the repository at this point in the history
In particular:
* add function names so it's clear which decorated functions and rules
  are causing the error;
* when possible (i.e. when the primal function has already been run), check
  that the pytree structure and leaf shapes/dtypes of the primal function's
  output agree with those of the primal output produced by the rule

context: lucidrains/flash-attention-jax#7
  • Loading branch information
mattjj committed Oct 1, 2022
1 parent eb0fa40 commit 430f3d9
Show file tree
Hide file tree
Showing 2 changed files with 239 additions and 32 deletions.
149 changes: 120 additions & 29 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax.core import raise_to_shaped
from jax.errors import UnexpectedTracerError
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
Expand Down Expand Up @@ -80,6 +80,15 @@ def _stop_gradient(x):
else:
return x

# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux
def _flatten_fun_nokwargs(in_tree, *args_flat):
  """Flattening wrapper that also records the output leaves' abstract values.

  Rebuilds the positional arguments from ``args_flat`` using ``in_tree``,
  yields them (with an empty kwargs dict) to obtain the wrapped computation's
  result, then flattens that result.

  Yields:
    First, ``(positional_args, {})``.
    Second, ``(flat_outputs, (output_treedef, output_avals))``, where
    ``output_avals`` holds one ``core.raise_to_shaped(core.get_aval(leaf))``
    entry per output leaf; the second tuple is the auxiliary value.
  """
  positional_args = tree_unflatten(in_tree, args_flat)
  result = yield positional_args, {}
  flat_outputs, output_treedef = tree_flatten(result)
  # Capture per-leaf avals so callers can compare output types in error checks.
  output_avals = [
      core.raise_to_shaped(core.get_aval(leaf)) for leaf in flat_outputs
  ]
  yield flat_outputs, (output_treedef, output_avals)


### JVPs
ReturnValue = TypeVar('ReturnValue')
Expand Down Expand Up @@ -205,9 +214,11 @@ def jvp(primals, tangents):

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
if not self.jvp:
msg = "No JVP defined for custom_jvp function {} using defjvp."
raise AttributeError(msg.format(self.__name__))
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)
jvp_name = getattr(self.jvp, '__name__', str(self.jvp))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
Expand All @@ -222,10 +233,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
f_, dyn_args = lu.wrap_init(self.fun), args
jvp = lu.wrap_init(self.jvp)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree,
out_type1)
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
_, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2)
return tree_unflatten(out_tree, out_flat)

def _add_args(f, extra_args):
Expand All @@ -238,22 +250,59 @@ def _add_args_(extra_args, *args, **kwargs):
yield (yield all_args, kwargs)

@lu.transformation_with_aux
def _flatten_jvp(in_tree, *args):
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
primals_in, tangents_in = split_list(args, [len(args) // 2])
py_primals = tree_unflatten(in_tree, primals_in)
py_tangents = tree_unflatten(in_tree, tangents_in)
pair_out = yield (py_primals, py_tangents), {}
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = ("Custom JVP rule must produce a pair (list or tuple of length two) "
"representing primal and tangent outputs, got {}.")
raise TypeError(msg.format(pair_out))
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) representing "
f"primal and tangent outputs, but got {pair_out}.")
raise TypeError(msg)
py_primals_out, py_tangents_out = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
tangents_out, out_tree2 = tree_flatten(py_tangents_out)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
if out_tree != out_tree2:
msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
"container (pytree) structures, but got {} and {} respectively.")
raise TypeError(msg.format(out_tree, out_tree2))
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce primal and tangent outputs with equal container (pytree) "
f"structures, but got {out_tree} and {out_tree2} respectively.")
raise TypeError(msg)
# If the primal function already ran, check out_tree agreement.
try: out_type_ = maybe_out_type()
except lu.StoreException: out_type_ = None
if out_type_ is not None:
out_tree_, primal_avals_ = out_type_
ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
if out_tree_ != out_tree:
m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce a pair (list or tuple of length two) "
"where the first element represents the primal output "
"(equal in value to the output of the custom_jvp-decorated function "
f"{primal_name}, "
"and in particular of the same container/pytree structure), but "
"instead the JVP rule output's first element had container/pytree "
"structure:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_jvp-decorated function {primal_name} had output "
"container/pytree structure:\n"
f""" {str(ty_tree_).replace("'", "")}.""")
raise TypeError(m)
if not all(map(core.typematch, primal_avals, primal_avals_)):
m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce a pair (list or tuple of length two) "
"where the first element represents the primal output "
"(equal in value to the output of the custom_jvp-decorated function "
f"{primal_name}, "
"and in particular with leaves of the same shape/dtype), but "
"instead the JVP rule output's first element had shapes/dtypes of:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_jvp-decorated function {primal_name} had output "
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
# TODO(mattjj): compare primals' tangent types to tangent objects' types
primal_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
Expand All @@ -274,7 +323,7 @@ def _flatten_jvp(in_tree, *args):
f" primal {av1.str_short()} for tangent {av2.str_short()}"
for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, out_tree
yield primals_out + tangents_out, (out_tree, primal_avals)

class CustomJVPCallPrimitive(core.Primitive):
multiple_results = True
Expand Down Expand Up @@ -472,9 +521,11 @@ def f_bwd(res, g):

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
if not self.fwd or not self.bwd:
msg = "No VJP defined for custom_vjp function {} using defvjp."
raise AttributeError(msg.format(self.__name__))
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = _resolve_kwargs(self.fun, args, kwargs)
if config.jax_enable_custom_vjp_by_custom_transpose:
if self.nondiff_argnums:
Expand All @@ -497,13 +548,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
return tree_unflatten(out_tree, out_flat)

def _check_for_tracers(x):
Expand All @@ -519,19 +570,59 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)

@lu.transformation_with_aux
def _flatten_fwd(in_tree, *args):
def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args):
py_args = tree_unflatten(in_tree, args)
pair_out = yield py_args, {}
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = ("Custom VJP fwd function must produce a pair (list or tuple of "
"length two) representing primal outputs and residuals (values "
"stored from the forward pass for use on the backward pass), "
"got {}.")
raise TypeError(msg.format(pair_out))
py_outs, res = pair_out
out, out_tree = tree_flatten(py_outs)
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
"element represents the primal output (equal to those of the "
f"custom_vjp-decorated function {primal_name}) and the "
"second element represents residuals (i.e. values stored from the "
"forward pass for use on the backward pass), but "
f"instead of a pair the fwd rule {fwd_name} produced {pair_out}.")
raise TypeError(msg)
py_primals_out, res = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
res, res_tree = tree_flatten(res)
yield res + out, (out_tree, res_tree)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
# If the primal function already ran, check out_tree agreement.
try: out_type_ = maybe_out_type()
except lu.StoreException: out_type_ = None
if out_type_ is not None:
out_tree_, primal_avals_ = out_type_
ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
if out_tree_ != out_tree:
m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
"element represents the primal output "
"(equal to the output of the custom_vjp-decorated function "
f"{primal_name}) and the "
"second element represents residuals (i.e. values stored from the "
"forward pass for use on the backward pass), but "
"instead the fwd rule output's first element had container/pytree "
"structure:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_vjp-decorated function {primal_name} had output "
"container/pytree structure:\n"
f""" {str(ty_tree_).replace("'", "")}.""")
raise TypeError(m)
if not all(map(core.typematch, primal_avals, primal_avals_)):
m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} must "
"produce a pair (list or tuple of length two) "
"where the first element represents the primal output "
"(equal to the output of the custom_vjp-decorated function "
f"{primal_name}) and the second element represents residuals "
"(i.e. values stored from the forward pass for use on the "
"backward pass), but "
"instead the fwd rule output's first element had shapes/dtypes of:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_vjp-decorated function {primal_name} had output "
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
yield (*res, *primals_out), (out_tree, res_tree)

@lu.transformation
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
Expand Down
122 changes: 119 additions & 3 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6322,7 +6322,8 @@ def foo_jvp(primals, tangents):
self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule must produce primal and tangent outputs "
"Custom JVP rule foo_jvp for function f "
"must produce primal and tangent outputs "
"with equal container (pytree) structures, but got "
"{} and {} respectively.".format(
tree_util.tree_structure((1,)),
Expand Down Expand Up @@ -6367,10 +6368,64 @@ def foo_jvp(primals, tangents):
self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule must produce a pair (list or tuple of length two) "
"representing primal and tangent outputs, got 1.0"),
"Custom JVP rule foo_jvp for function f "
"must produce a pair (list or tuple of length two) "
"representing primal and tangent outputs, but got 1.0"),
lambda: api.jvp(f, (2.,), (1.,)))

def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self):
# https://github.com/lucidrains/flash-attention-jax/issues/7

def scan_apply(f, x):
y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1)
return y

@jax.custom_jvp
def f(x):
return x

@f.defjvp
def f_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return (x, x), (xdot, xdot)

self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule f_jvp for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal in value to the output of the "
"custom_jvp-decorated function f, and in particular of the "
"same container/pytree structure), but instead the JVP rule "
"output's first element had container/pytree structure:\n"
" (float32[], float32[])\n"
"while the custom_jvp-decorated function f had output "
"container/pytree structure:\n"
" float32[]."
),
lambda: jax.jvp(lambda x: scan_apply(f, x), (1.,), (1.,)))

@f.defjvp
def f_jvp2(primals, tangents):
(x,), (xdot,) = primals, tangents
return jnp.zeros((3, *x.shape), x.dtype), xdot

self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule f_jvp for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal in value to the output of the "
"custom_jvp-decorated function f, and in particular "
"with leaves of the same shape/dtype), but instead the JVP rule "
"output's first element had shapes/dtypes of:\n"
" float32[3]\n"
"while the custom_jvp-decorated function f had output shapes/dtypes"
" of:\n"
" float32[]"
),
lambda: jax.jvp(lambda x: scan_apply(f, x), (1.,), (1.,)))

def test_multiple_rule_invocations(self):
@jax.custom_jvp
def expit(x):
Expand Down Expand Up @@ -7361,6 +7416,67 @@ def foo_bwd(_, g):
with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"):
api.grad(f)(3.)

def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self):
# https://github.com/lucidrains/flash-attention-jax/issues/7

def scan_apply(f, x):
y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1)
return y

@jax.custom_vjp
def f(x):
return x

def f_fwd(x):
return (x, x), None

def f_bwd(_, y_bar):
return (y_bar,)

f.defvjp(f_fwd, f_bwd)

self.assertRaisesRegex(
TypeError,
re.escape(
"Custom VJP fwd rule f_fwd for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal to the output of the "
"custom_vjp-decorated function f) and the second element "
"represents residuals (i.e. values stored from the forward "
"pass for use on the backward pass), but instead the fwd rule "
"output's first element had container/pytree structure:\n"
" (float32[], float32[])\n"
"while the custom_vjp-decorated function f had output "
"container/pytree structure:\n"
" float32[]."
),
lambda: jax.grad(lambda x: scan_apply(f, x))(1.))

def f_fwd2(x):
return jnp.zeros((3, *x.shape), x.dtype), None

def f_bwd2(_, y_bar):
return (y_bar,)

f.defvjp(f_fwd2, f_bwd2)

self.assertRaisesRegex(
TypeError,
re.escape(
"Custom VJP fwd rule f_fwd for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal to the output of the "
"custom_vjp-decorated function f) and the second element "
"represents residuals (i.e. values stored from the forward "
"pass for use on the backward pass), but instead the fwd rule "
"output's first element had shapes/dtypes of:\n"
" float32[3]\n"
"while the custom_vjp-decorated function f had output "
"shapes/dtypes of:\n"
" float32[]"
),
lambda: jax.grad(lambda x: scan_apply(f, x))(1.))

def test_issue2511(self):
arr = jnp.ones((5, 2, 2))
foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)
Expand Down

0 comments on commit 430f3d9

Please sign in to comment.