Skip to content

Commit

Permalink
filter_eval_shape now alwasys has a .out_struct property
Browse files Browse the repository at this point in the history
Previously, this was erroneously skipped when the function lacked any closed-over variables.
In addition, this commit adds eqxi.cached_filter_eval_shape, as that is needed for the above.
  • Loading branch information
patrick-kidger committed Sep 11, 2023
1 parent c269e15 commit cc2df94
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 31 deletions.
91 changes: 60 additions & 31 deletions equinox/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from ._custom_types import sentinel
from ._deprecate import deprecated_0_10
from ._doc_utils import doc_remove_args
from ._eval_shape import cached_filter_eval_shape
from ._filters import (
combine,
is_array,
Expand Down Expand Up @@ -416,6 +417,47 @@ def _unflatten(flat_pytree):
_FlatPyTree = tuple[list[_T], PyTreeDef]


def _check_closure_convert_input(self, args, kwargs):
self_in_dynamic_struct = _unflatten(self.in_dynamic_struct)
self_in_static = _unflatten(self.in_static)
in_dynamic, in_static = partition((args, kwargs), is_array)
in_dynamic_struct = jax.eval_shape(lambda: in_dynamic)
# `is` because `tree_equal` may return a tracer
if tree_equal(in_dynamic_struct, self_in_dynamic_struct) is not True:
raise ValueError(
"Closure-converted function called with different dynamic arguments to "
"the example arguments provided."
)
if tree_equal(in_static, self_in_static) is not True:
raise ValueError(
"Closure-converted function called with different static arguments to "
"the example arguments provided."
)
return in_dynamic


class _TrivialClosureConvert(Module):
fn: types.FunctionType
in_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)
in_static: _FlatPyTree[Any] = field(static=True)

@property
def in_struct(self):
dynamic = _unflatten(self.in_dynamic_struct)
static = _unflatten(self.in_static)
return combine(dynamic, static)

@property
def out_struct(self):
args, kwargs = self.in_struct
return cached_filter_eval_shape(self.fn, *args, **kwargs)

def __call__(self, *args, **kwargs):
# unused output
_ = _check_closure_convert_input(self, args, kwargs)
return self.fn(*args, **kwargs)


class _ClosureConvert(Module):
# Important that `jaxpr` be a leaf (and not static), so that it is a tuple element
# when passing through `filter_primitive_bind` and thus visible to
Expand All @@ -440,27 +482,13 @@ def out_struct(self):
return combine(dynamic, static)

def __call__(self, *args, **kwargs):
self_in_dynamic_struct = _unflatten(self.in_dynamic_struct)
self_out_dynamic_struct = _unflatten(self.out_dynamic_struct)
self_in_static = _unflatten(self.in_static)
self_out_static = _unflatten(self.out_static)
in_dynamic, in_static = partition((args, kwargs), is_array)
in_dynamic_struct = jax.eval_shape(lambda: in_dynamic)
# `is` because `tree_equal` may return a tracer
if tree_equal(in_dynamic_struct, self_in_dynamic_struct) is not True:
raise ValueError(
"Closure-converted function called with different dynamic arguments to "
"the example arguments provided."
)
if tree_equal(in_static, self_in_static) is not True:
raise ValueError(
"Closure-converted function called with different static arguments to "
"the example arguments provided."
)
in_dynamic = _check_closure_convert_input(self, args, kwargs)
in_dynamic_flat = jtu.tree_leaves(in_dynamic)
out_dynamic_flat = jax.core.eval_jaxpr(
self.jaxpr, self.consts, *in_dynamic_flat
)
self_out_dynamic_struct = _unflatten(self.out_dynamic_struct)
self_out_static = _unflatten(self.out_static)
out_dynamic_struct_flat, out_dynamic_treedef = jtu.tree_flatten(
self_out_dynamic_struct
)
Expand Down Expand Up @@ -509,24 +537,25 @@ def f(x, y):
f(1., 1.)
```
"""
if isinstance(fn, types.FunctionType) and fn.__closure__ is None:
# In this case, it's not possible to have any closed-over tracers.
# Skip jaxpr tracing for efficiency.
return fn
closed_jaxpr, out_dynamic_struct, out_static = filter_make_jaxpr(fn)(
*args, **kwargs
) # pyright: ignore
in_dynamic, in_static = partition((args, kwargs), _is_struct)
in_dynamic_struct = jax.eval_shape(lambda: in_dynamic)
jaxpr = closed_jaxpr.jaxpr
consts = closed_jaxpr.consts
in_dynamic_struct = jtu.tree_flatten(in_dynamic_struct)
out_dynamic_struct = jtu.tree_flatten(out_dynamic_struct)
in_static = jtu.tree_flatten(in_static)
out_static = jtu.tree_flatten(out_static)
closure_converted = _ClosureConvert(
jaxpr, consts, in_dynamic_struct, out_dynamic_struct, in_static, out_static
)
if isinstance(fn, types.FunctionType) and fn.__closure__ is None:
# In this case, it's not possible to have any closed-over tracers.
# Skip jaxpr tracing for efficiency.
closure_converted = _TrivialClosureConvert(fn, in_dynamic_struct, in_static)
else:
closed_jaxpr, out_dynamic_struct, out_static = filter_make_jaxpr(fn)(
*args, **kwargs
) # pyright: ignore
jaxpr = closed_jaxpr.jaxpr
consts = closed_jaxpr.consts
out_dynamic_struct = jtu.tree_flatten(out_dynamic_struct)
out_static = jtu.tree_flatten(out_static)
closure_converted = _ClosureConvert(
jaxpr, consts, in_dynamic_struct, out_dynamic_struct, in_static, out_static
)
closure_converted = cast(Callable[_P, _T], closure_converted)
return closure_converted

Expand Down
25 changes: 25 additions & 0 deletions equinox/_eval_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

import jax
import jax._src.traceback_util as traceback_util
import jax.tree_util as jtu
from jaxtyping import PyTree

from ._caches import internal_lru_caches
from ._filters import combine, is_array, partition
from ._module import Static

Expand Down Expand Up @@ -35,3 +37,26 @@ def _fn(_static, _dynamic):
dynamic, static = partition((fun, args, kwargs), _filter)
dynamic_out, static_out = jax.eval_shape(ft.partial(_fn, static), dynamic)
return combine(dynamic_out, static_out.value)


def _to_struct(x):
if is_array(x):
return jax.ShapeDtypeStruct(x.shape, x.dtype)
else:
return x


@ft.lru_cache(maxsize=None)
def _cached_filter_eval_shape(leaves, treedef):
fn, args, kwargs = jtu.tree_unflatten(treedef, leaves)
return filter_eval_shape(fn, *args, **kwargs)


internal_lru_caches.append(_cached_filter_eval_shape)


def cached_filter_eval_shape(fn, *args, **kwargs):
tree = jtu.tree_map(_to_struct, (fn, args, kwargs))
leaves, treedef = jtu.tree_flatten(tree)
leaves = tuple(leaves)
return _cached_filter_eval_shape(leaves, treedef)
1 change: 1 addition & 0 deletions equinox/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
branched_error_if as branched_error_if,
error_if as error_if,
)
from .._eval_shape import cached_filter_eval_shape as cached_filter_eval_shape
from .._misc import left_broadcast_to as left_broadcast_to
from .._module import Static as Static
from .._pretty_print import tree_pp as tree_pp
Expand Down
9 changes: 9 additions & 0 deletions tests/test_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,15 @@ def f(x, y):
f(1.0, 1.0)


def test_closure_convert_trivial():
def f(a):
return a + 1

f2 = eqx.filter_closure_convert(f, 1)
f2.out_struct
assert type(f2).__name__ == "_TrivialClosureConvert"


def test_closure_convert_custom_jvp():
@eqx.filter_custom_jvp
def call(f, x):
Expand Down

0 comments on commit cc2df94

Please sign in to comment.