From ec73cfcf99b3c65b1d2400d2f8174f03335743ce Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 14 Nov 2024 11:23:25 -0800 Subject: [PATCH] [nnx] add checkify --- flax/nnx/__init__.py | 1 + flax/nnx/transforms/transforms.py | 79 +++++++++++++++++++++++++++++-- tests/nnx/transforms_test.py | 19 +++++++- 3 files changed, 94 insertions(+), 5 deletions(-) diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index affa691d07..6a27b090f5 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -151,6 +151,7 @@ from .transforms.transforms import eval_shape as eval_shape from .transforms.transforms import cond as cond from .transforms.transforms import switch as switch +from .transforms.transforms import checkify as checkify from .transforms.iteration import while_loop as while_loop from .transforms.iteration import fori_loop as fori_loop from .transforms.iteration import StateAxes as StateAxes diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index b74dd18c30..787bc38958 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -15,12 +15,16 @@ from __future__ import annotations from abc import abstractmethod +import dataclasses import functools import inspect import typing as tp +from jax._src import checkify as checkify_lib + from flax.nnx import ( extract, + graph, ) from flax.nnx.module import Module from flax.nnx.proxy_caller import ( @@ -119,7 +123,7 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): # ------------------------------- -# eval_shape +# simple transforms # ------------------------------- @@ -140,9 +144,76 @@ def _eval_shape_fn(*args, **kwargs): return extract.from_tree(out) -# ------------------------------- -# cond and switch -# ------------------------------- +@dataclasses.dataclass(eq=False) +class CheckifyFn: + f: tp.Callable[..., tp.Any] + + def __post_init__(self): + functools.update_wrapper(self, self.f) + + def __call__(self, *pure_args, **pure_kwargs): + args, kwargs = extract.from_tree( + (pure_args, pure_kwargs), ctxtag='checkify' + ) + out = self.f(*args, **kwargs) + + args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs)) + pure_args_out, pure_kwargs_out, pure_out = extract.to_tree( + (args, kwargs, out), ctxtag='checkify' + ) + return pure_args_out, pure_kwargs_out, pure_out + +def checkify( + f: tp.Callable[..., checkify_lib.Out], + errors: frozenset[type[checkify_lib.JaxException]] = checkify_lib.user_checks, # type: ignore +) -> tp.Callable[..., tuple[checkify_lib.Error, checkify_lib.Out]]: + """Reference-aware version of `jax.experimental.checkify + `_. + + Example:: + + >>> import jax + >>> import jax.numpy as jnp + >>> from jax.experimental import checkify + >>> import dataclasses + >>> from flax import nnx + ... + >>> @dataclasses.dataclass + ... class Foo(nnx.Module): + ... a: nnx.Param + ... + >>> @nnx.jit + ... def f(m): + ... y = jnp.sin(m.a.value) # error + ... return m.a + y + ... + >>> m = Foo(a=nnx.Param(jnp.inf)) + >>> err, out = nnx.checkify(f, errors=checkify.float_checks)(m) + >>> # err.throw() + >>> print(err) + Error(nan generated by primitive: sin.) + """ + checkify_fn = checkify_lib.checkify(CheckifyFn(f), errors) + + @functools.wraps(f) + @graph.update_context('checkify') + def jit_wrapper(*args, **kwargs): + pure_args, pure_kwargs = extract.to_tree( + (args, kwargs), + ctxtag='checkify', + ) + error, (pure_args_out, pure_kwargs_out, pure_out) = checkify_fn( + *pure_args, **pure_kwargs + ) + + args_out, kwargs_out, out = extract.from_tree( + (pure_args_out, pure_kwargs_out, pure_out), + ctxtag='checkify', + ) + + return error, out + + return jit_wrapper # type: ignore @general.split_inputs(ctxtag='cond') diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 4c327e1970..2ec49a9ca9 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -21,11 +21,12 @@ from flax import nnx from flax.nnx.transforms import general import jax -from jax.experimental import mesh_utils +from jax.experimental import mesh_utils, checkify import jax.numpy as jnp import numpy as np + class List(nnx.Module): def __init__(self, items): vars(self).update({str(i): item for i, item in enumerate(items)}) @@ -3024,6 +3025,22 @@ def no_nothing(env: Env): env.step.value, np.array([1, 0, 1, 0, 1, 0, 1, 0], np.uint32) ) +class TestCheckify(absltest.TestCase): + def test_basic(self): + @dataclasses.dataclass + class Foo(nnx.Module): + a: nnx.Param + + @nnx.jit + def f(m): + y = jnp.sin(m.a.value) # error + return m.a + y + + m = Foo(a=nnx.Param(jnp.inf)) + err, out = nnx.checkify(f, errors=checkify.float_checks)(m) + + with self.assertRaisesRegex(ValueError, 'nan generated by primitive: sin'): + err.throw() if __name__ == '__main__': absltest.main()