From 430f3d9720611fcfbb6328233cb5903d928339e3 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Thu, 29 Sep 2022 20:01:42 -0700 Subject: [PATCH] improve custom_jvp/vjp error messages In particular: * add function names so it's clear what decorated functions and rules are causing the error; * when possible (because the functions were run), check for agreement of pytree structure and leaf shapes/dtypes between the primal function and rules context: https://github.com/lucidrains/flash-attention-jax/issues/7 --- jax/_src/custom_derivatives.py | 149 ++++++++++++++++++++++++++------- tests/api_test.py | 122 ++++++++++++++++++++++++++- 2 files changed, 239 insertions(+), 32 deletions(-) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 36b05fd8b96b..4365ecf4def6 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -28,7 +28,7 @@ from jax._src import dtypes from jax._src.lax import lax from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable -from jax._src.api_util import flatten_fun_nokwargs, argnums_partial +from jax._src.api_util import argnums_partial, flatten_fun_nokwargs from jax.core import raise_to_shaped from jax.errors import UnexpectedTracerError from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p @@ -80,6 +80,15 @@ def _stop_gradient(x): else: return x +# like the api_util.py function, but also grabs output avals for error checking +@lu.transformation_with_aux +def _flatten_fun_nokwargs(in_tree, *args_flat): + py_args = tree_unflatten(in_tree, args_flat) + ans = yield py_args, {} + ans_flat, ans_tree = tree_flatten(ans) + ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat] + yield ans_flat, (ans_tree, ans_avals) + ### JVPs ReturnValue = TypeVar('ReturnValue') @@ -205,9 +214,11 @@ def jvp(primals, tangents): @traceback_util.api_boundary def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation + 
primal_name = getattr(self.fun, '__name__', str(self.fun)) if not self.jvp: - msg = "No JVP defined for custom_jvp function {} using defjvp." - raise AttributeError(msg.format(self.__name__)) + msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." + raise AttributeError(msg) + jvp_name = getattr(self.jvp, '__name__', str(self.jvp)) args = _resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) @@ -222,10 +233,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable f_, dyn_args = lu.wrap_init(self.fun), args jvp = lu.wrap_init(self.jvp) args_flat, in_tree = tree_flatten(dyn_args) - flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree) - flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree) + flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree) + flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree, + out_type1) out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat) - _, out_tree = lu.merge_linear_aux(out_tree1, out_tree2) + _, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2) return tree_unflatten(out_tree, out_flat) def _add_args(f, extra_args): @@ -238,22 +250,59 @@ def _add_args_(extra_args, *args, **kwargs): yield (yield all_args, kwargs) @lu.transformation_with_aux -def _flatten_jvp(in_tree, *args): +def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): primals_in, tangents_in = split_list(args, [len(args) // 2]) py_primals = tree_unflatten(in_tree, primals_in) py_tangents = tree_unflatten(in_tree, tangents_in) pair_out = yield (py_primals, py_tangents), {} if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: - msg = ("Custom JVP rule must produce a pair (list or tuple of length two) " - "representing primal and tangent outputs, got {}.") - raise TypeError(msg.format(pair_out)) + msg = (f"Custom JVP rule {jvp_name} for function {primal_name} " + "must produce a pair (list or tuple of 
length two) representing " + f"primal and tangent outputs, but got {pair_out}.") + raise TypeError(msg) py_primals_out, py_tangents_out = pair_out primals_out, out_tree = tree_flatten(py_primals_out) tangents_out, out_tree2 = tree_flatten(py_tangents_out) + primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] if out_tree != out_tree2: - msg = ("Custom JVP rule must produce primal and tangent outputs with equal " - "container (pytree) structures, but got {} and {} respectively.") - raise TypeError(msg.format(out_tree, out_tree2)) + msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must " + "produce primal and tangent outputs with equal container (pytree) " + f"structures, but got {out_tree} and {out_tree2} respectively.") + raise TypeError(msg) + # If the primal function already ran, check out_tree agreement. + try: out_type_ = maybe_out_type() + except lu.StoreException: out_type_ = None + if out_type_ is not None: + out_tree_, primal_avals_ = out_type_ + ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) + ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) + if out_tree_ != out_tree: + m = (f"Custom JVP rule {jvp_name} for function {primal_name} must " + "produce a pair (list or tuple of length two) " + "where the first element represents the primal output " + "(equal in value to the output of the custom_jvp-decorated function " + f"{primal_name}, " + "and in particular of the same container/pytree structure), but " + "instead the JVP rule output's first element had container/pytree " + "structure:\n" + f""" {str(ty_tree ).replace("'", "")}\n""" + f"while the custom_jvp-decorated function {primal_name} had output " + "container/pytree structure:\n" + f""" {str(ty_tree_).replace("'", "")}.""") + raise TypeError(m) + if not all(map(core.typematch, primal_avals, primal_avals_)): + m = (f"Custom JVP rule {jvp_name} for function {primal_name} must " + "produce a pair (list or tuple 
of length two) " + "where the first element represents the primal output " + "(equal in value to the output of the custom_jvp-decorated function " + f"{primal_name}, " + "and in particular with leaves of the same shape/dtype), but " + "instead the JVP rule output's first element had shapes/dtypes of:\n" + f""" {str(ty_tree ).replace("'", "")}\n""" + f"while the custom_jvp-decorated function {primal_name} had output " + "shapes/dtypes of:\n" + f""" {str(ty_tree_).replace("'", "")}""") + raise TypeError(m) # TODO(mattjj): compare primals' tangent types to tangent objects' types primal_avals_out = [ raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape() @@ -274,7 +323,7 @@ def _flatten_jvp(in_tree, *args): f" primal {av1.str_short()} for tangent {av2.str_short()}" for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2) raise TypeError(msg.format('\n'.join(disagreements))) - yield primals_out + tangents_out, out_tree + yield primals_out + tangents_out, (out_tree, primal_avals) class CustomJVPCallPrimitive(core.Primitive): multiple_results = True @@ -472,9 +521,11 @@ def f_bwd(res, g): @traceback_util.api_boundary def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation + primal_name = getattr(self.fun, '__name__', str(self.fun)) if not self.fwd or not self.bwd: - msg = "No VJP defined for custom_vjp function {} using defvjp." - raise AttributeError(msg.format(self.__name__)) + msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." 
+ raise AttributeError(msg) + fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) args = _resolve_kwargs(self.fun, args, kwargs) if config.jax_enable_custom_vjp_by_custom_transpose: if self.nondiff_argnums: @@ -497,13 +548,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd) args_flat, in_tree = tree_flatten(dyn_args) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] - flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree) - flat_fwd, out_trees = _flatten_fwd(fwd, in_tree) + flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree) + flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree, + out_type) flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees) out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd, *args_flat, out_trees=out_trees) - fst, aux = lu.merge_linear_aux(out_tree, out_trees) - out_tree = aux if fst else aux[0] + _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees) return tree_unflatten(out_tree, out_flat) def _check_for_tracers(x): @@ -519,19 +570,59 @@ def _check_for_tracers(x): raise UnexpectedTracerError(msg) @lu.transformation_with_aux -def _flatten_fwd(in_tree, *args): +def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args): py_args = tree_unflatten(in_tree, args) pair_out = yield py_args, {} if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2: - msg = ("Custom VJP fwd function must produce a pair (list or tuple of " - "length two) representing primal outputs and residuals (values " - "stored from the forward pass for use on the backward pass), " - "got {}.") - raise TypeError(msg.format(pair_out)) - py_outs, res = pair_out - out, out_tree = tree_flatten(py_outs) + msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " + "must produce a pair (list or tuple of length two) where the first " + "element represents the primal output (equal to those of the 
" + f"custom_vjp-decorated function {primal_name}) and the " + "second element represents residuals (i.e. values stored from the " + "forward pass for use on the backward pass), but " + f"instead of a pair the fwd rule {fwd_name} produced {pair_out}.") + raise TypeError(msg) + py_primals_out, res = pair_out + primals_out, out_tree = tree_flatten(py_primals_out) res, res_tree = tree_flatten(res) - yield res + out, (out_tree, res_tree) + primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out] + # If the primal function already ran, check out_tree agreement. + try: out_type_ = maybe_out_type() + except lu.StoreException: out_type_ = None + if out_type_ is not None: + out_tree_, primal_avals_ = out_type_ + ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals]) + ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_]) + if out_tree_ != out_tree: + m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} " + "must produce a pair (list or tuple of length two) where the first " + "element represents the primal output " + "(equal to the output of the custom_vjp-decorated function " + f"{primal_name}) and the " + "second element represents residuals (i.e. values stored from the " + "forward pass for use on the backward pass), but " + "instead the fwd rule output's first element had container/pytree " + "structure:\n" + f""" {str(ty_tree ).replace("'", "")}\n""" + f"while the custom_vjp-decorated function {primal_name} had output " + "container/pytree structure:\n" + f""" {str(ty_tree_).replace("'", "")}.""") + raise TypeError(m) + if not all(map(core.typematch, primal_avals, primal_avals_)): + m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} must " + "produce a pair (list or tuple of length two) " + "where the first element represents the primal output " + "(equal to the output of the custom_vjp-decorated function " + f"{primal_name}) and the second element represents residuals " + "(i.e. 
values stored from the forward pass for use on the " + "backward pass), but " + "instead the fwd rule output's first element had shapes/dtypes of:\n" + f""" {str(ty_tree ).replace("'", "")}\n""" + f"while the custom_vjp-decorated function {primal_name} had output " + "shapes/dtypes of:\n" + f""" {str(ty_tree_).replace("'", "")}""") + raise TypeError(m) + yield (*res, *primals_out), (out_tree, res_tree) @lu.transformation def _flatten_bwd(in_tree, in_avals, out_trees, *args): diff --git a/tests/api_test.py b/tests/api_test.py index 7b17be1a0fd6..e2df60ddf053 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -6322,7 +6322,8 @@ def foo_jvp(primals, tangents): self.assertRaisesRegex( TypeError, re.escape( - "Custom JVP rule must produce primal and tangent outputs " + "Custom JVP rule foo_jvp for function f " + "must produce primal and tangent outputs " "with equal container (pytree) structures, but got " "{} and {} respectively.".format( tree_util.tree_structure((1,)), @@ -6367,10 +6368,64 @@ def foo_jvp(primals, tangents): self.assertRaisesRegex( TypeError, re.escape( - "Custom JVP rule must produce a pair (list or tuple of length two) " - "representing primal and tangent outputs, got 1.0"), + "Custom JVP rule foo_jvp for function f " + "must produce a pair (list or tuple of length two) " + "representing primal and tangent outputs, but got 1.0"), lambda: api.jvp(f, (2.,), (1.,))) + def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_jvp + def f(x): + return x + + @f.defjvp + def f_jvp(primals, tangents): + (x,), (xdot,) = primals, tangents + return (x, x), (xdot, xdot) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + 
"the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular of the " + "same container/pytree structure), but instead the JVP rule " + "output's first element had container/pytree structure:\n" + " (float32[], float32[])\n" + "while the custom_jvp-decorated function f had output " + "container/pytree structure:\n" + " float32[]." + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (1.,), (1.,))) + + @f.defjvp + def f_jvp2(primals, tangents): + (x,), (xdot,) = primals, tangents + return jnp.zeros((3, *x.shape), x.dtype), xdot + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom JVP rule f_jvp2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal in value to the output of the " + "custom_jvp-decorated function f, and in particular " + "with leaves of the same shape/dtype), but instead the JVP rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_jvp-decorated function f had output shapes/dtypes" + " of:\n" + " float32[]" + ), + lambda: jax.jvp(lambda x: scan_apply(f, x), (1.,), (1.,))) + def test_multiple_rule_invocations(self): @jax.custom_jvp def expit(x): @@ -7361,6 +7416,67 @@ def foo_bwd(_, g): with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"): api.grad(f)(3.) 
+ def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self): + # https://github.com/lucidrains/flash-attention-jax/issues/7 + + def scan_apply(f, x): + y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1) + return y + + @jax.custom_vjp + def f(x): + return x + + def f_fwd(x): + return (x, x), None + + def f_bwd(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd, f_bwd) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP fwd rule f_fwd for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal to the output of the " + "custom_vjp-decorated function f) and the second element " + "represents residuals (i.e. values stored from the forward " + "pass for use on the backward pass), but instead the fwd rule " + "output's first element had container/pytree structure:\n" + " (float32[], float32[])\n" + "while the custom_vjp-decorated function f had output " + "container/pytree structure:\n" + " float32[]." + ), + lambda: jax.grad(lambda x: scan_apply(f, x))(1.)) + + def f_fwd2(x): + return jnp.zeros((3, *x.shape), x.dtype), None + + def f_bwd2(_, y_bar): + return (y_bar,) + + f.defvjp(f_fwd2, f_bwd2) + + self.assertRaisesRegex( + TypeError, + re.escape( + "Custom VJP fwd rule f_fwd2 for function f must produce a pair " + "(list or tuple of length two) where the first element represents " + "the primal output (equal to the output of the " + "custom_vjp-decorated function f) and the second element " + "represents residuals (i.e. values stored from the forward " + "pass for use on the backward pass), but instead the fwd rule " + "output's first element had shapes/dtypes of:\n" + " float32[3]\n" + "while the custom_vjp-decorated function f had output " + "shapes/dtypes of:\n" + " float32[]" + ), + lambda: jax.grad(lambda x: scan_apply(f, x))(1.)) + def test_issue2511(self): arr = jnp.ones((5, 2, 2)) foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)