improve custom_jvp/vjp error messages #12611

Merged
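
A minimal sketch of the failure mode these new messages describe, adapted from the tests added in this PR (the snippet is illustrative, not part of the diff, and assumes only jax and jax.numpy):

import jax
import jax.numpy as jnp

def scan_apply(f, x):
  # Tracing f under scan also runs the decorated primal function, so its
  # output pytree structure and avals are available to compare against
  # what the custom JVP rule later produces.
  y, _ = jax.lax.scan(lambda c, _: (f(c), None), x, None, length=1)
  return y

@jax.custom_jvp
def f(x):
  return x

@f.defjvp
def f_jvp(primals, tangents):
  (x,), (xdot,) = primals, tangents
  # Wrong on purpose: the primal part is a pair, but f returns a single scalar.
  return (x, x), (xdot, xdot)

x = jnp.float32(1.)
# With this change, the resulting TypeError names f_jvp and f and prints the
# mismatched pytree structures, (float32[], float32[]) versus float32[].
jax.jvp(lambda x: scan_apply(f, x), (x,), (x,))
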
149 changes: 120 additions & 29 deletions jax/_src/custom_derivatives.py
@@ -28,7 +28,7 @@
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax.core import raise_to_shaped
from jax.errors import UnexpectedTracerError
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
@@ -80,6 +80,15 @@ def _stop_gradient(x):
else:
return x

# like the api_util.py function, but also grabs output avals for error checking
@lu.transformation_with_aux
def _flatten_fun_nokwargs(in_tree, *args_flat):
py_args = tree_unflatten(in_tree, args_flat)
ans = yield py_args, {}
ans_flat, ans_tree = tree_flatten(ans)
ans_avals = [core.raise_to_shaped(core.get_aval(x)) for x in ans_flat]
yield ans_flat, (ans_tree, ans_avals)


### JVPs
ReturnValue = TypeVar('ReturnValue')
@@ -205,9 +214,11 @@ def jvp(primals, tangents):

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
if not self.jvp:
msg = "No JVP defined for custom_jvp function {} using defjvp."
raise AttributeError(msg.format(self.__name__))
msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp."
raise AttributeError(msg)
jvp_name = getattr(self.jvp, '__name__', str(self.jvp))
args = _resolve_kwargs(self.fun, args, kwargs)
if self.nondiff_argnums:
nondiff_argnums = set(self.nondiff_argnums)
@@ -222,10 +233,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
f_, dyn_args = lu.wrap_init(self.fun), args
jvp = lu.wrap_init(self.jvp)
args_flat, in_tree = tree_flatten(dyn_args)
flat_fun, out_tree1 = flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_tree2 = _flatten_jvp(jvp, in_tree)
flat_fun, out_type1 = _flatten_fun_nokwargs(f_, in_tree)
flat_jvp, out_type2 = _flatten_jvp(jvp, primal_name, jvp_name, in_tree,
out_type1)
out_flat = custom_jvp_call_p.bind(flat_fun, flat_jvp, *args_flat)
_, out_tree = lu.merge_linear_aux(out_tree1, out_tree2)
_, (out_tree, _) = lu.merge_linear_aux(out_type1, out_type2)
return tree_unflatten(out_tree, out_flat)

def _add_args(f, extra_args):
@@ -238,22 +250,59 @@ def _add_args_(extra_args, *args, **kwargs):
yield (yield all_args, kwargs)

@lu.transformation_with_aux
def _flatten_jvp(in_tree, *args):
def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args):
primals_in, tangents_in = split_list(args, [len(args) // 2])
py_primals = tree_unflatten(in_tree, primals_in)
py_tangents = tree_unflatten(in_tree, tangents_in)
pair_out = yield (py_primals, py_tangents), {}
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = ("Custom JVP rule must produce a pair (list or tuple of length two) "
"representing primal and tangent outputs, got {}.")
raise TypeError(msg.format(pair_out))
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) representing "
f"primal and tangent outputs, but got {pair_out}.")
raise TypeError(msg)
py_primals_out, py_tangents_out = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
tangents_out, out_tree2 = tree_flatten(py_tangents_out)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
if out_tree != out_tree2:
msg = ("Custom JVP rule must produce primal and tangent outputs with equal "
"container (pytree) structures, but got {} and {} respectively.")
raise TypeError(msg.format(out_tree, out_tree2))
msg = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce primal and tangent outputs with equal container (pytree) "
f"structures, but got {out_tree} and {out_tree2} respectively.")
raise TypeError(msg)
# If the primal function already ran, check out_tree agreement.
try: out_type_ = maybe_out_type()
except lu.StoreException: out_type_ = None
if out_type_ is not None:
out_tree_, primal_avals_ = out_type_
ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
if out_tree_ != out_tree:
m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce a pair (list or tuple of length two) "
"where the first element represents the primal output "
"(equal in value to the output of the custom_jvp-decorated function "
f"{primal_name}, "
"and in particular of the same container/pytree structure), but "
"instead the JVP rule output's first element had container/pytree "
"structure:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_jvp-decorated function {primal_name} had output "
"container/pytree structure:\n"
f""" {str(ty_tree_).replace("'", "")}.""")
raise TypeError(m)
if not all(map(core.typematch, primal_avals, primal_avals_)):
m = (f"Custom JVP rule {jvp_name} for function {primal_name} must "
"produce a pair (list or tuple of length two) "
"where the first element represents the primal output "
"(equal in value to the output of the custom_jvp-decorated function "
f"{primal_name}, "
"and in particular with leaves of the same shape/dtype), but "
"instead the JVP rule output's first element had shapes/dtypes of:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_jvp-decorated function {primal_name} had output "
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
# TODO(mattjj): compare primals' tangent types to tangent objects' types
primal_avals_out = [
raise_to_shaped(core.get_aval(x), weak_type=False).strip_named_shape()
@@ -274,7 +323,7 @@ def _flatten_jvp(in_tree, *args):
f" primal {av1.str_short()} for tangent {av2.str_short()}"
for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2)
raise TypeError(msg.format('\n'.join(disagreements)))
yield primals_out + tangents_out, out_tree
yield primals_out + tangents_out, (out_tree, primal_avals)

class CustomJVPCallPrimitive(core.Primitive):
multiple_results = True
@@ -472,9 +521,11 @@ def f_bwd(res, g):

@traceback_util.api_boundary
def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable=invalid-annotation
primal_name = getattr(self.fun, '__name__', str(self.fun))
if not self.fwd or not self.bwd:
msg = "No VJP defined for custom_vjp function {} using defvjp."
raise AttributeError(msg.format(self.__name__))
msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp."
raise AttributeError(msg)
fwd_name = getattr(self.fwd, '__name__', str(self.fwd))
args = _resolve_kwargs(self.fun, args, kwargs)
if config.jax_enable_custom_vjp_by_custom_transpose:
if self.nondiff_argnums:
@@ -497,13 +548,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
fwd, bwd = lu.wrap_init(self.fwd), lu.wrap_init(self.bwd)
args_flat, in_tree = tree_flatten(dyn_args)
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
flat_fun, out_tree = flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, in_tree)
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
flat_fwd, out_trees = _flatten_fwd(fwd, primal_name, fwd_name, in_tree,
out_type)
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees)
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
*args_flat, out_trees=out_trees)
fst, aux = lu.merge_linear_aux(out_tree, out_trees)
out_tree = aux if fst else aux[0]
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
return tree_unflatten(out_tree, out_flat)

def _check_for_tracers(x):
@@ -519,19 +570,59 @@ def _check_for_tracers(x):
raise UnexpectedTracerError(msg)

@lu.transformation_with_aux
def _flatten_fwd(in_tree, *args):
def _flatten_fwd(primal_name, fwd_name, in_tree, maybe_out_type, *args):
py_args = tree_unflatten(in_tree, args)
pair_out = yield py_args, {}
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
msg = ("Custom VJP fwd function must produce a pair (list or tuple of "
"length two) representing primal outputs and residuals (values "
"stored from the forward pass for use on the backward pass), "
"got {}.")
raise TypeError(msg.format(pair_out))
py_outs, res = pair_out
out, out_tree = tree_flatten(py_outs)
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
"element represents the primal output (equal to those of the "
f"custom_vjp-decorated function {primal_name}) and the "
"second element represents residuals (i.e. values stored from the "
"forward pass for use on the backward pass), but "
f"instead of a pair the fwd rule {fwd_name} produced {pair_out}.")
raise TypeError(msg)
py_primals_out, res = pair_out
primals_out, out_tree = tree_flatten(py_primals_out)
res, res_tree = tree_flatten(res)
yield res + out, (out_tree, res_tree)
primal_avals = [core.raise_to_shaped(core.get_aval(x)) for x in primals_out]
# If the primal function already ran, check out_tree agreement.
try: out_type_ = maybe_out_type()
except lu.StoreException: out_type_ = None
if out_type_ is not None:
out_tree_, primal_avals_ = out_type_
ty_tree = tree_unflatten(out_tree , [a.str_short() for a in primal_avals])
ty_tree_ = tree_unflatten(out_tree_, [a.str_short() for a in primal_avals_])
if out_tree_ != out_tree:
m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
"must produce a pair (list or tuple of length two) where the first "
"element represents the primal output "
"(equal to the output of the custom_vjp-decorated function "
f"{primal_name}) and the "
"second element represents residuals (i.e. values stored from the "
"forward pass for use on the backward pass), but "
"instead the fwd rule output's first element had container/pytree "
"structure:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_vjp-decorated function {primal_name} had output "
"container/pytree structure:\n"
f""" {str(ty_tree_).replace("'", "")}.""")
raise TypeError(m)
if not all(map(core.typematch, primal_avals, primal_avals_)):
m = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} must "
"produce a pair (list or tuple of length two) "
"where the first element represents the primal output "
"(equal to the output of the custom_vjp-decorated function "
f"{primal_name}) and the second element represents residuals "
"(i.e. values stored from the forward pass for use on the "
"backward pass), but "
"instead the fwd rule output's first element had shapes/dtypes of:\n"
f""" {str(ty_tree ).replace("'", "")}\n"""
f"while the custom_vjp-decorated function {primal_name} had output "
"shapes/dtypes of:\n"
f""" {str(ty_tree_).replace("'", "")}""")
raise TypeError(m)
yield (*res, *primals_out), (out_tree, res_tree)

@lu.transformation
def _flatten_bwd(in_tree, in_avals, out_trees, *args):
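
The custom_vjp changes above mirror the custom_jvp ones: _flatten_fwd now receives the primal and fwd-rule names plus the primal function's output type, and compares the fwd rule's first output element against it. A minimal sketch of the corresponding misuse, again adapted from the tests added below (illustrative, not part of the diff):

import jax
import jax.numpy as jnp

def scan_apply(f, x):
  y, _ = jax.lax.scan(lambda c, _: (f(c), None), x, None, length=1)
  return y

@jax.custom_vjp
def f(x):
  return x

def f_fwd(x):
  # Wrong on purpose: the primal part is a pair, but f returns a single scalar.
  return (x, x), None

def f_bwd(_, y_bar):
  return (y_bar,)

f.defvjp(f_fwd, f_bwd)

# With this change, the TypeError names f_fwd and f and prints the mismatched
# container/pytree structures instead of a generic message.
jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.))
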
123 changes: 120 additions & 3 deletions tests/api_test.py
@@ -6322,7 +6322,8 @@ def foo_jvp(primals, tangents):
self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule must produce primal and tangent outputs "
"Custom JVP rule foo_jvp for function f "
"must produce primal and tangent outputs "
"with equal container (pytree) structures, but got "
"{} and {} respectively.".format(
tree_util.tree_structure((1,)),
@@ -6367,10 +6368,65 @@ def foo_jvp(primals, tangents):
self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule must produce a pair (list or tuple of length two) "
"representing primal and tangent outputs, got 1.0"),
"Custom JVP rule foo_jvp for function f "
"must produce a pair (list or tuple of length two) "
"representing primal and tangent outputs, but got 1.0"),
lambda: api.jvp(f, (2.,), (1.,)))

def test_jvp_rule_primal_out_type_doesnt_match_primal_error_message(self):
# https://github.com/lucidrains/flash-attention-jax/issues/7

def scan_apply(f, x):
y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1)
return y

@jax.custom_jvp
def f(x):
return x

@f.defjvp
def f_jvp(primals, tangents):
(x,), (xdot,) = primals, tangents
return (x, x), (xdot, xdot)

x = jnp.float32(1.)
self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule f_jvp for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal in value to the output of the "
"custom_jvp-decorated function f, and in particular of the "
"same container/pytree structure), but instead the JVP rule "
"output's first element had container/pytree structure:\n"
" (float32[], float32[])\n"
"while the custom_jvp-decorated function f had output "
"container/pytree structure:\n"
" float32[]."
),
lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,)))

@f.defjvp
def f_jvp2(primals, tangents):
(x,), (xdot,) = primals, tangents
return jnp.zeros((3, *x.shape), x.dtype), xdot

self.assertRaisesRegex(
TypeError,
re.escape(
"Custom JVP rule f_jvp2 for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal in value to the output of the "
"custom_jvp-decorated function f, and in particular "
"with leaves of the same shape/dtype), but instead the JVP rule "
"output's first element had shapes/dtypes of:\n"
" float32[3]\n"
"while the custom_jvp-decorated function f had output shapes/dtypes"
" of:\n"
" float32[]"
),
lambda: jax.jvp(lambda x: scan_apply(f, x), (x,), (x,)))

def test_multiple_rule_invocations(self):
@jax.custom_jvp
def expit(x):
Expand Down Expand Up @@ -7361,6 +7417,67 @@ def foo_bwd(_, g):
with self.assertRaisesRegex(TypeError, "Custom VJP rule .* must produce a tuple"):
api.grad(f)(3.)

def test_fwd_rule_primal_out_type_doesnt_match_primal_error_message(self):
# https://github.com/lucidrains/flash-attention-jax/issues/7

def scan_apply(f, x):
y, _ = jax.lax.scan(lambda x, _: (f(x), None), x, None, length=1)
return y

@jax.custom_vjp
def f(x):
return x

def f_fwd(x):
return (x, x), None

def f_bwd(_, y_bar):
return (y_bar,)

f.defvjp(f_fwd, f_bwd)

self.assertRaisesRegex(
TypeError,
re.escape(
"Custom VJP fwd rule f_fwd for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal to the output of the "
"custom_vjp-decorated function f) and the second element "
"represents residuals (i.e. values stored from the forward "
"pass for use on the backward pass), but instead the fwd rule "
"output's first element had container/pytree structure:\n"
" (float32[], float32[])\n"
"while the custom_vjp-decorated function f had output "
"container/pytree structure:\n"
" float32[]."
),
lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.)))

def f_fwd2(x):
return jnp.zeros((3, *x.shape), x.dtype), None

def f_bwd2(_, y_bar):
return (y_bar,)

f.defvjp(f_fwd2, f_bwd2)

self.assertRaisesRegex(
TypeError,
re.escape(
"Custom VJP fwd rule f_fwd2 for function f must produce a pair "
"(list or tuple of length two) where the first element represents "
"the primal output (equal to the output of the "
"custom_vjp-decorated function f) and the second element "
"represents residuals (i.e. values stored from the forward "
"pass for use on the backward pass), but instead the fwd rule "
"output's first element had shapes/dtypes of:\n"
" float32[3]\n"
"while the custom_vjp-decorated function f had output "
"shapes/dtypes of:\n"
" float32[]"
),
lambda: jax.grad(lambda x: scan_apply(f, x))(jnp.float32(1.)))

def test_issue2511(self):
arr = jnp.ones((5, 2, 2))
foo = lambda x: api.vmap(jnp.linalg.det, (0,))(x)
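
For comparison, a rule written the intended way passes the new checks because the JVP's primal output matches the decorated function's output in pytree structure, shape, and dtype. This is ordinary custom_jvp usage, included here only as a usage note:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def g(x):
  return jnp.sin(x)

@g.defjvp
def g_jvp(primals, tangents):
  (x,), (xdot,) = primals, tangents
  # The primal output equals g(x) exactly (same structure, shape, and dtype),
  # so none of the new error paths fire.
  return jnp.sin(x), jnp.cos(x) * xdot

print(jax.jvp(g, (jnp.float32(1.),), (jnp.float32(1.),)))
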