Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] explicit Variables #3720

Merged
merged 1 commit into from
Mar 1, 2024
Merged

[nnx] explicit Variables #3720

merged 1 commit into from
Mar 1, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Feb 26, 2024

What does this PR do?

  • The Variable.value attribute which contained the unprocessed values is now named .raw_value.
  • Replaces Variable.get_value and Variable.set_value with a .value property.
  • Module getattr and setattr no longer extract the inner value of Variables.
  • Removes Module.variables and State.variables helpers.

Basic example now looks like this:

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w: Param[Array] = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b: Param[Array] = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x: jax.Array):
    return x @ self.w.value + self.b.value

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov-commenter
Copy link

codecov-commenter commented Feb 26, 2024

Codecov Report

Attention: Patch coverage is 95.66474% with 15 lines in your changes are missing coverage. Please review.

Project coverage is 58.85%. Comparing base (acba0bf) to head (0d347d0).

Files Patch % Lines
flax/experimental/nnx/nnx/variables.py 87.75% 6 Missing ⚠️
flax/experimental/nnx/nnx/spmd.py 57.14% 3 Missing ⚠️
flax/experimental/nnx/nnx/state.py 70.00% 3 Missing ⚠️
flax/experimental/nnx/nnx/nn/linear.py 83.33% 2 Missing ⚠️
flax/experimental/nnx/nnx/nn/normalization.py 66.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3720      +/-   ##
==========================================
- Coverage   59.05%   58.85%   -0.21%     
==========================================
  Files         103      103              
  Lines       12440    12318     -122     
==========================================
- Hits         7347     7250      -97     
+ Misses       5093     5068      -25     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@cgarciae cgarciae force-pushed the nnx-explicit-variables branch 2 times, most recently from 37450b8 to 6ad6c96 Compare February 29, 2024 22:03
@cgarciae cgarciae force-pushed the nnx-explicit-variables branch 4 times, most recently from 4d5de7f to 4997d59 Compare March 1, 2024 14:35
@cgarciae cgarciae force-pushed the nnx-explicit-variables branch 2 times, most recently from d1a67cf to 0d347d0 Compare March 1, 2024 15:30
@copybara-service copybara-service bot merged commit 7d6de00 into main Mar 1, 2024
21 checks passed
@copybara-service copybara-service bot deleted the nnx-explicit-variables branch March 1, 2024 22:11
@NeilGirdhar
Copy link
Contributor

Woohoo!!! Awesome!!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants