Skip to content

Commit

Permalink
Fix missing out_axes for transpose-of-vprim.
Browse files Browse the repository at this point in the history
In particular this resulted in a trace error here: patrick-kidger/lineax#101
  • Loading branch information
patrick-kidger committed Aug 7, 2024
1 parent 028a6f4 commit c04c1f2
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import PyTree

Expand Down Expand Up @@ -404,20 +405,41 @@ def _resolve_undefined_b(input, batch_axis):
def _vprim_transpose(
cts, *inputs, prim, __axis_size, __axis_name, __batch_axes, params
):
inputs = [_resolve_undefined_i(i, b) for i, b in zip(inputs, __batch_axes)]
mapped_inputs = [_resolve_undefined_i(i, b) for i, b in zip(inputs, __batch_axes)]
batch_axes = [_resolve_undefined_b(i, b) for i, b in zip(inputs, __batch_axes)]
transpose = ft.partial(_vprim_transpose_registry[prim], **params)

def _transpose(*_inputs):
_outputs = _vprim_transpose_registry[prim](*_inputs, **params)
# `Zero` is not a JAX type -- it's an internal AD thing -- so we shouldn't pass
# it across the `vmap` boundary. In particular JAX won't apply the out batch
# axis to it.
# JAX allows for returning `None` to indicate no cotangent, so we use that
# instead, which is compatible with both `vmap` and `out_axes`.
return tuple(None if type(o) is ad.Zero else o for o in _outputs)

transpose = jax.vmap(
transpose,
_transpose,
in_axes=(0, *batch_axes),
out_axes=__batch_axes,
axis_size=__axis_size,
axis_name=__axis_name,
)
if prim.multiple_results:
cts = tuple(None if type(c) is ad.Zero else c for c in cts)
else:
cts = None if type(cts) is ad.Zero else cts
return transpose(cts, *inputs)
outputs = transpose(cts, *mapped_inputs)
assert len(inputs) == len(outputs)
for i, o in zip(inputs, outputs):
if o is not None:
# Can't have cotangents on defined variables I think? The point of an
# `UndefinedPrimal` is to declare what you want cotangents with respect to.
assert type(i) is ad.UndefinedPrimal
# We've filtered out all other avals above, with a `NotImplementedError` if
# required.
assert isinstance(i.aval, jax.core.ShapedArray)
assert i.aval.shape == jnp.shape(o)
return outputs


# _vprim_p is itself a vprim!
Expand Down

0 comments on commit c04c1f2

Please sign in to comment.