Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lint flax.nnx.fori_loop docstring #4370

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,7 @@ def fori_loop(lower: int, upper: int,
init_val: T,
*,
unroll: int | bool | None = None) -> T:
"""NNX transform of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.
"""A Flax NNX transformation of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.

Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the variable reference structure of `init_val` inside `body_fun`.
Expand All @@ -1509,21 +1509,21 @@ def fori_loop(lower: int, upper: int,


Args:
lower: an integer representing the loop index lower bound (inclusive)
upper: an integer representing the loop index upper bound (exclusive)
body_fun: a function that takes an input of type `T` and outputs an `T`.
Note that both data and modules of `T` must have the same reference
lower: An integer representing the loop index lower bound (inclusive).
upper: An integer representing the loop index upper bound (exclusive).
body_fun: a function that takes an input of type ``T`` and outputs an ``T``.
Note that both data and modules of ``T`` must have the same reference
structure between inputs and outputs.
init_val: the initial input for body_fun. Must be of type `T`.
init_val: the initial input for body_fun. Must be of type ``T``.
unroll: An optional integer or boolean that determines how much to unroll
the loop. If an integer is provided, it determines how many unrolled
loop iterations to run within a single rolled iteration of the loop. If a
boolean is provided, it will determine if the loop is competely unrolled
(i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`).
(i.e. ``unroll=True``) or left completely unrolled (i.e. ``unroll=False``).
This argument is only applicable if the loop bounds are statically known.

Returns:
Loop value from the final iteration, of type ``T``.
A loop value from the final iteration, of type ``T``.

"""

Expand All @@ -1537,4 +1537,4 @@ def fori_loop(lower: int, upper: int,
ForiLoopBodyFn(body_fun), pure_init_val,
unroll=unroll)
out = extract.from_tree(pure_out, ctxtag='fori_loop')
return out
return out
Loading