Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add open in colab button to why nnx #3596

Merged
merged 1 commit into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flax/experimental/nnx/docs/why.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
"source": [
"# Why NNX?\n",
"\n",
"<!-- open in colab button -->\n",
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)\n",
"\n",
"Four years ago we developed the Flax \"Linen\" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years.\n",
"\n",
"We introduced some ideas that have proven to be good:\n",
Expand All @@ -32,6 +35,7 @@
"metadata": {},
"outputs": [],
"source": [
"! pip install -U git+https://github.com/google/flax.git\n",
"from functools import partial\n",
"import jax\n",
"from jax import random, numpy as jnp\n",
Expand Down
4 changes: 4 additions & 0 deletions flax/experimental/nnx/docs/why.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ jupytext:

# Why NNX?

<!-- open in colab button -->
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)

Four years ago we developed the Flax "Linen" API to support modeling research on JAX, with a focus on scaling scaling and performance. We've learned a lot from our users over these years.

We introduced some ideas that have proven to be good:
Expand All @@ -30,6 +33,7 @@ We'd love to hear from any of our users about their thoughts on these ideas.
[[this doc on github](https://github.com/google/flax/blob/main/flax/experimental/nnx/docs/why.ipynb)]

```{code-cell}
! pip install -U git+https://github.com/google/flax.git
from functools import partial
import jax
from jax import random, numpy as jnp
Expand Down
Loading