From 6b5b300625d3317577a61796e3c475931dc85d2b Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:38:00 +0000 Subject: [PATCH 1/2] Add flax.nnx.eval_shape docstring --- flax/nnx/transforms/transforms.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index b74dd18c30..371c988bdf 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -138,7 +138,14 @@ def _eval_shape_fn(*args, **kwargs): out = jax.eval_shape(_eval_shape_fn, *args, **kwargs) return extract.from_tree(out) - + """A "lifted" version of `jax.eval_shape `_ + that can handle `flax.nnx.Module `_ + / graph nodes as arguments. + + Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without + performing any floating point operations (FLOPs) which can be expensive. This can be + useful for performing shape inference, for example. + """ # ------------------------------- # cond and switch From ddd7847beeee4245f885ca7af2d19dd8ca03bbca Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 11 Nov 2024 22:53:56 +0000 Subject: [PATCH 2/2] Add flax.nnx.eval_shape docstring --- flax/nnx/transforms/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flax/nnx/transforms/transforms.py b/flax/nnx/transforms/transforms.py index 371c988bdf..03eb91c4e2 100644 --- a/flax/nnx/transforms/transforms.py +++ b/flax/nnx/transforms/transforms.py @@ -139,12 +139,12 @@ def _eval_shape_fn(*args, **kwargs): out = jax.eval_shape(_eval_shape_fn, *args, **kwargs) return extract.from_tree(out) """A "lifted" version of `jax.eval_shape `_ - that can handle `flax.nnx.Module `_ - / graph nodes as arguments. + that can handle `flax.nnx.Module `_ + / graph nodes as arguments. Similar to ``jax.eval_shape``, it computes the shape/dtype of a function `f` without - performing any floating point operations (FLOPs) which can be expensive. This can be - useful for performing shape inference, for example. + performing any floating point operations (FLOPs) which can be expensive. This can be + useful for performing shape inference, for example. """ # -------------------------------