Skip to content

Commit

Permalink
Lint flax.nnx.fori_loop docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 authored Nov 11, 2024
1 parent d31f290 commit 84fa22e
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,7 @@ def fori_loop(lower: int, upper: int,
init_val: T,
*,
unroll: int | bool | None = None) -> T:
"""NNX transform of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.
"""A Flax NNX transformation of `jax.lax.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`_.
Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the variable reference structure of `init_val` inside `body_fun`.
Expand All @@ -1509,21 +1509,21 @@ def fori_loop(lower: int, upper: int,
Args:
lower: an integer representing the loop index lower bound (inclusive)
upper: an integer representing the loop index upper bound (exclusive)
body_fun: a function that takes an input of type `T` and outputs an `T`.
Note that both data and modules of `T` must have the same reference
lower: An integer representing the loop index lower bound (inclusive).
upper: An integer representing the loop index upper bound (exclusive).
body_fun: a function that takes an input of type ``T`` and outputs an ``T``.
Note that both data and modules of ``T`` must have the same reference
structure between inputs and outputs.
init_val: the initial input for body_fun. Must be of type `T`.
init_val: the initial input for body_fun. Must be of type ``T``.
unroll: An optional integer or boolean that determines how much to unroll
the loop. If an integer is provided, it determines how many unrolled
loop iterations to run within a single rolled iteration of the loop. If a
boolean is provided, it will determine if the loop is competely unrolled
(i.e. `unroll=True`) or left completely unrolled (i.e. `unroll=False`).
(i.e. ``unroll=True``) or left completely unrolled (i.e. ``unroll=False``).
This argument is only applicable if the loop bounds are statically known.
Returns:
Loop value from the final iteration, of type ``T``.
A loop value from the final iteration, of type ``T``.
"""

Expand All @@ -1537,4 +1537,4 @@ def fori_loop(lower: int, upper: int,
ForiLoopBodyFn(body_fun), pure_init_val,
unroll=unroll)
out = extract.from_tree(pure_out, ctxtag='fori_loop')
return out
return out

0 comments on commit 84fa22e

Please sign in to comment.