From c04c1f28cf6488dcb753bdfa0d2e61f3a1da6777 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Wed, 7 Aug 2024 18:25:03 +0200 Subject: [PATCH] Fix missing `out_axes` for transpose-of-vprim. In particular this resulted in a trace error here: https://github.com/patrick-kidger/lineax/issues/101 --- equinox/internal/_primitive.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/equinox/internal/_primitive.py b/equinox/internal/_primitive.py index bdc32d3a..d044901b 100644 --- a/equinox/internal/_primitive.py +++ b/equinox/internal/_primitive.py @@ -6,6 +6,7 @@ import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir +import jax.numpy as jnp import jax.tree_util as jtu from jaxtyping import PyTree @@ -404,12 +405,22 @@ def _resolve_undefined_b(input, batch_axis): def _vprim_transpose( cts, *inputs, prim, __axis_size, __axis_name, __batch_axes, params ): - inputs = [_resolve_undefined_i(i, b) for i, b in zip(inputs, __batch_axes)] + mapped_inputs = [_resolve_undefined_i(i, b) for i, b in zip(inputs, __batch_axes)] batch_axes = [_resolve_undefined_b(i, b) for i, b in zip(inputs, __batch_axes)] - transpose = ft.partial(_vprim_transpose_registry[prim], **params) + + def _transpose(*_inputs): + _outputs = _vprim_transpose_registry[prim](*_inputs, **params) + # `Zero` is not a JAX type -- it's an internal AD thing -- so we shouldn't pass + # it across the `vmap` boundary. In particular JAX won't apply the out batch + # axis to it. + # JAX allows for returning `None` to indicate no cotangent, so we use that + # instead, which is compatible with both `vmap` and `out_axes`. + return tuple(None if type(o) is ad.Zero else o for o in _outputs) + transpose = jax.vmap( - transpose, + _transpose, in_axes=(0, *batch_axes), + out_axes=__batch_axes, axis_size=__axis_size, axis_name=__axis_name, ) @@ -417,7 +428,18 @@ def _vprim_transpose( cts = tuple(None if type(c) is ad.Zero else c for c in cts) else: cts = None if type(cts) is ad.Zero else cts - return transpose(cts, *inputs) + outputs = transpose(cts, *mapped_inputs) + assert len(inputs) == len(outputs) + for i, o in zip(inputs, outputs): + if o is not None: + # Can't have cotangents on defined variables I think? The point of an + # `UndefinedPrimal` is to declare what you want cotangents with respect to. + assert type(i) is ad.UndefinedPrimal + # We've filtered out all other avals above, with a `NotImplementedError` if + # required. + assert isinstance(i.aval, jax.core.ShapedArray) + assert i.aval.shape == jnp.shape(o) + return outputs # _vprim_p is itself a vprim!