diff --git a/flax/linen/module.py b/flax/linen/module.py index 16f5f491ec..ab948ced79 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -329,54 +329,54 @@ def intercept_methods(interceptor: Interceptor): the underlying method. Or you could completely skip calling the underlying method and decide to do something differently. For example:: - >>> import flax.linen as nn - >>> import jax.numpy as jnp - ... - >>> class Foo(nn.Module): - ... def __call__(self, x): - ... return x - ... - >>> def my_interceptor1(next_fun, args, kwargs, context): - ... print('calling my_interceptor1') - ... return next_fun(*args, **kwargs) - ... - >>> foo = Foo() - >>> with nn.intercept_methods(my_interceptor1): - ... _ = foo(jnp.ones([1])) - calling my_interceptor1 + >>> import flax.linen as nn + >>> import jax.numpy as jnp + ... + >>> class Foo(nn.Module): + ... def __call__(self, x): + ... return x + ... + >>> def my_interceptor1(next_fun, args, kwargs, context): + ... print('calling my_interceptor1') + ... return next_fun(*args, **kwargs) + ... + >>> foo = Foo() + >>> with nn.intercept_methods(my_interceptor1): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 You could also register multiple interceptors on the same method. Interceptors will run in order. For example:: - >>> def my_interceptor2(next_fun, args, kwargs, context): - ... print('calling my_interceptor2') - ... return next_fun(*args, **kwargs) - ... - >>> with nn.intercept_methods(my_interceptor1), \ - ... nn.intercept_methods(my_interceptor2): - ... _ = foo(jnp.ones([1])) - calling my_interceptor1 - calling my_interceptor2 + >>> def my_interceptor2(next_fun, args, kwargs, context): + ... print('calling my_interceptor2') + ... return next_fun(*args, **kwargs) + ... + >>> with nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 + calling my_interceptor2 You could skip other interceptors by directly calling the ``context.orig_method``. For example:: - >>> def my_interceptor3(next_fun, args, kwargs, context): - ... print('calling my_interceptor3') - ... return context.orig_method(*args, **kwargs) - >>> with nn.intercept_methods(my_interceptor3), \ - ... nn.intercept_methods(my_interceptor1), \ - ... nn.intercept_methods(my_interceptor2): - ... _ = foo(jnp.ones([1])) - calling my_interceptor3 + >>> def my_interceptor3(next_fun, args, kwargs, context): + ... print('calling my_interceptor3') + ... return context.orig_method(*args, **kwargs) + >>> with nn.intercept_methods(my_interceptor3), \ + ... nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor3 The following methods couldn't be intercepted: - 1. Methods decoratored with ``nn.nowrap``. - 2. Dunder methods including '__eq__', '__repr__', '__init__', '__hash__', - and '__post_init__'. - 3. Module dataclass fields. - 4. Module descriptors. + 1. Methods decoratored with ``nn.nowrap``. + 2. Dunder methods including '__eq__', '__repr__', '__init__', '__hash__', + and '__post_init__'. + 3. Module dataclass fields. + 4. Module descriptors. Args: interceptor: A method interceptor.