
[nnx] Add LinearGeneral and MultiHeadAttention #3487

Merged: 2 commits into main on Nov 29, 2023

Conversation

cgarciae (Collaborator)

What does this PR do?

  • Ports DenseGeneral as LinearGeneral
  • Ports MultiHeadDotProductAttention as MultiHeadAttention (usage sketch below)
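
For context, a minimal usage sketch of the two new modules, mirroring the unit tests added in this PR:

import jax.numpy as jnp
from flax.experimental import nnx

# LinearGeneral can project a feature axis onto multiple output axes.
linear = nnx.LinearGeneral(2, (3, 4), rngs=nnx.Rngs(0))
y = linear(jnp.ones((1, 2)))   # y.shape == (1, 3, 4)

# MultiHeadAttention performs self-attention when only inputs_q is given.
attn = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
y = attn(jnp.ones((1, 7, 3)))  # y.shape == (1, 7, 6)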

codecov-commenter commented Nov 22, 2023

Codecov Report

Attention: 67 lines in your changes are missing coverage. Please review.

Comparison: base (e172c76) at 53.33% vs. head (0a1f78a) at 53.81%.
Report is 1 commit behind head on main.

Files                                        Patch %   Lines missing
flax/experimental/nnx/nnx/nn/attention.py   63.87%    56 ⚠️
flax/experimental/nnx/nnx/nn/linear.py      87.34%    10 ⚠️
flax/experimental/nnx/nnx/module.py         50.00%     1 ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3487      +/-   ##
==========================================
+ Coverage   53.33%   53.81%   +0.48%     
==========================================
  Files          95       98       +3     
  Lines       11252    11513     +261     
==========================================
+ Hits         6001     6196     +195     
- Misses       5251     5317      +66     



Comment on lines 476 to 488
warnings.warn(
  f'You are passing an array of shape {inputs_v.shape} '
  'to the `inputs_v` arg, when you may have intended '
  'to pass it to the `mask` arg. As of Flax version '
  '0.7.4, the function signature of '
  "MultiHeadAttention's `__call__` method "
  'has changed to `__call__(inputs_q, inputs_k=None, '
  'inputs_v=None, *, inputs_kv=None, mask=None, '
  'deterministic=None)`. Use the kwarg `mask` instead. '
  'See https://github.com/google/flax/discussions/3389 '
  'and read the docstring for more information.',
  DeprecationWarning,
)
Collaborator

Not sure the warning message is relevant, since the NNX attention layer started out with this new call signature of __call__(self, inputs_q, inputs_k, inputs_v, *, inputs_kv, ...). But I do think linking the Flax discussions page could still be useful to give users context.
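
For reference, a sketch of the call-style difference the warning guards against. The module construction and input shapes mirror this PR's tests; the mask shape is an assumption based on Linen's attention (broadcastable to (batch, num_heads, q_length, kv_length)):

import jax.numpy as jnp
from flax.experimental import nnx

module = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
inputs_q = jnp.ones((1, 7, 3))
inputs_kv = jnp.ones((1, 5, 3))
mask = jnp.ones((1, 1, 7, 5), dtype=bool)

# Old Linen-style positional call: the third argument now binds to
# `inputs_v`, which is exactly what the shape check above guards against.
# y = module(inputs_q, inputs_kv, mask)

# New style: pass keys/values explicitly and the mask by keyword.
y = module(inputs_q, inputs_kv, inputs_kv, mask=mask)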

deterministic: if false, the attention weight is masked randomly using
  dropout, whereas if true, the attention weights are deterministic.
dropout_rng: optional rng key to pass to the attention layer's dropout
  mask. Otherwise, self.make_rng('dropout') is used instead.
@chiamp (Collaborator) commented Nov 27, 2023

self.make_rng is only relevant in Flax, I believe? Or does NNX also use it? Is it equivalent to rngs.dropout()?
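
For context, a rough sketch of the difference. The Linen side is the standard apply-time rngs pattern; the NNX constructor is the one used in this PR's tests, and treating rngs.dropout() as the NNX counterpart of self.make_rng('dropout') is an assumption:

from flax.experimental import nnx

# Linen: dropout keys are supplied at apply-time and the layer draws them
# internally via self.make_rng('dropout'):
#   y = linen_module.apply(params, x, rngs={'dropout': jax.random.PRNGKey(0)})

# NNX: an nnx.Rngs object is supplied up front, and the module draws keys
# from its streams (e.g. rngs.dropout(); assumed, per the question above).
rngs = nnx.Rngs(0)
module = nnx.MultiHeadAttention(2, 3, 6, rngs=rngs)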

Comment on lines +1 to +10
import jax.numpy as jnp

from flax.experimental import nnx


class TestMultiHeadAttention:
  def test_basic(self):
    module = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 7, 3)))
    assert y.shape == (1, 7, 6)
@chiamp (Collaborator) commented Nov 27, 2023

Can we port over the tests from tests/linen/linen_attention_test.py?
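
As one hypothetical example of what a ported case could look like (the cross-attention call convention here is an assumption based on the __call__ signature discussed above):

  def test_cross_attention(self):
    module = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
    q = jnp.ones((1, 7, 3))
    kv = jnp.ones((1, 5, 3))
    y = module(q, kv, kv)  # keys and values share the same input
    assert y.shape == (1, 7, 6)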

Comment on lines +1 to +23
import jax.numpy as jnp

from flax.experimental import nnx


class TestLinearGeneral:
  def test_basic(self):
    module = nnx.LinearGeneral(2, 3, rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 2)))

    assert y.shape == (1, 3)
    assert module.kernel.shape == (2, 3)
    assert module.bias is not None
    assert module.bias.shape == (3,)

  def test_basic_multi_features(self):
    module = nnx.LinearGeneral(2, (3, 4), rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 2)))

    assert y.shape == (1, 3, 4)
    assert module.kernel.shape == (2, 3, 4)
    assert module.bias is not None
    assert module.bias.shape == (3, 4)
Collaborator

Can we port over the tests from tests/linen/linen_linear_test.py?
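
A hypothetical example of a ported case (the axis argument is an assumption carried over from Linen's DenseGeneral):

  def test_multiple_input_axes(self):
    # Contract the last two input axes at once.
    module = nnx.LinearGeneral((2, 3), 4, axis=(-2, -1), rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 2, 3)))
    assert y.shape == (1, 4)
    assert module.kernel.shape == (2, 3, 4)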

Comment on lines +143 to +146
# DeprecationWarning: pkg_resources is deprecated as an API.
"ignore:.*pkg_resources is deprecated as an API.*:DeprecationWarning",
# DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
"ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning",
Collaborator

Where are these warnings coming from?

@copybara-service bot merged commit a572f6a into main on Nov 29, 2023
21 checks passed
@copybara-service bot deleted the nnx-mha branch on November 29, 2023 at 00:17

Example usage::

>>> import flax.linen as nn
Contributor

This docstring example seems to point to Linen rather than NNX?
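
Presumably the example should import NNX instead; e.g., mirroring this PR's tests:

>>> from flax.experimental import nnx
>>> layer = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))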
