Skip to content

Commit

Permalink
fixed intercept_methods docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Feb 14, 2024
1 parent c25f546 commit dbee3a0
Showing 1 changed file with 37 additions and 37 deletions.
74 changes: 37 additions & 37 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,54 +329,54 @@ def intercept_methods(interceptor: Interceptor):
the underlying method. Or you could completely skip calling the underlying
method and decide to do something differently. For example::
>>> import flax.linen as nn
>>> import jax.numpy as jnp
...
>>> class Foo(nn.Module):
... def __call__(self, x):
... return x
...
>>> def my_interceptor1(next_fun, args, kwargs, context):
... print('calling my_interceptor1')
... return next_fun(*args, **kwargs)
...
>>> foo = Foo()
>>> with nn.intercept_methods(my_interceptor1):
... _ = foo(jnp.ones([1]))
calling my_interceptor1
>>> import flax.linen as nn
>>> import jax.numpy as jnp
...
>>> class Foo(nn.Module):
... def __call__(self, x):
... return x
...
>>> def my_interceptor1(next_fun, args, kwargs, context):
... print('calling my_interceptor1')
... return next_fun(*args, **kwargs)
...
>>> foo = Foo()
>>> with nn.intercept_methods(my_interceptor1):
... _ = foo(jnp.ones([1]))
calling my_interceptor1
You could also register multiple interceptors on the same method. Interceptors
will run in order. For example::
>>> def my_interceptor2(next_fun, args, kwargs, context):
... print('calling my_interceptor2')
... return next_fun(*args, **kwargs)
...
>>> with nn.intercept_methods(my_interceptor1), \
... nn.intercept_methods(my_interceptor2):
... _ = foo(jnp.ones([1]))
calling my_interceptor1
calling my_interceptor2
>>> def my_interceptor2(next_fun, args, kwargs, context):
... print('calling my_interceptor2')
... return next_fun(*args, **kwargs)
...
>>> with nn.intercept_methods(my_interceptor1), \
... nn.intercept_methods(my_interceptor2):
... _ = foo(jnp.ones([1]))
calling my_interceptor1
calling my_interceptor2
You could skip other interceptors by directly calling the
``context.orig_method``. For example::
>>> def my_interceptor3(next_fun, args, kwargs, context):
... print('calling my_interceptor3')
... return context.orig_method(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor3), \
... nn.intercept_methods(my_interceptor1), \
... nn.intercept_methods(my_interceptor2):
... _ = foo(jnp.ones([1]))
calling my_interceptor3
>>> def my_interceptor3(next_fun, args, kwargs, context):
... print('calling my_interceptor3')
... return context.orig_method(*args, **kwargs)
>>> with nn.intercept_methods(my_interceptor3), \
... nn.intercept_methods(my_interceptor1), \
... nn.intercept_methods(my_interceptor2):
... _ = foo(jnp.ones([1]))
calling my_interceptor3
The following methods couldn't be intercepted:
1. Methods decoratored with ``nn.nowrap``.
2. Dunder methods including '__eq__', '__repr__', '__init__', '__hash__',
and '__post_init__'.
3. Module dataclass fields.
4. Module descriptors.
1. Methods decoratored with ``nn.nowrap``.
2. Dunder methods including '__eq__', '__repr__', '__init__', '__hash__',
and '__post_init__'.
3. Module dataclass fields.
4. Module descriptors.
Args:
interceptor: A method interceptor.
Expand Down

0 comments on commit dbee3a0

Please sign in to comment.