[nnx] Add LinearGeneral and MultiHeadAttention #3487
Conversation
Codecov Report

Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #3487      +/-   ##
==========================================
+ Coverage   53.33%   53.81%   +0.48%
==========================================
  Files          95       98       +3
  Lines       11252    11513     +261
==========================================
+ Hits         6001     6196     +195
- Misses       5251     5317      +66

☔ View full report in Codecov by Sentry.
warnings.warn(
  f'You are passing an array of shape {inputs_v.shape} '
  'to the `inputs_v` arg, when you may have intended '
  'to pass it to the `mask` arg. As of Flax version '
  '0.7.4, the function signature of '
  "MultiHeadAttention's `__call__` method "
  'has changed to `__call__(inputs_q, inputs_k=None, '
  'inputs_v=None, *, inputs_kv=None, mask=None, '
  'deterministic=None)`. Use the kwarg `mask` instead. '
  'See https://github.com/google/flax/discussions/3389 '
  'and read the docstring for more information.',
  DeprecationWarning,
)
Not sure the warning message is relevant here, since the NNX attention layer started out with the new call signature `__call__(self, inputs_q, inputs_k, inputs_v, *, inputs_kv, ...)`. But I do think linking the Flax discussions page could still be useful to give users context.
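For context, a minimal sketch of what the keyword-only `mask` usage looks like under that signature. Shapes follow the test added in this PR; the boolean mask shape `(1, 1, 7, 7)` is an assumption that relies on broadcasting over the heads dimension.

```python
import jax.numpy as jnp
from flax.experimental import nnx

# Same construction as the test in this PR: 2 heads, 3 input features, 6 qkv features.
module = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
x = jnp.ones((1, 7, 3))

# Passing the mask positionally would bind it to `inputs_k`/`inputs_v`;
# it has to go through the keyword-only `mask` argument instead.
mask = jnp.ones((1, 1, 7, 7), dtype=bool)  # assumed to broadcast over the 2 heads
y = module(x, mask=mask)
assert y.shape == (1, 7, 6)
```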
deterministic: if false, the attention weight is masked randomly using
  dropout, whereas if true, the attention weights are deterministic.
dropout_rng: optional rng key to pass to the attention layer's dropout
  mask. Otherwise, self.make_rng('dropout') is used instead.
`self.make_rng` is only relevant in Flax Linen, I believe? Or does NNX also use it? Is it equivalent to `rngs.dropout()`?
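For reference, a rough sketch of how the two RNG mechanisms line up, assuming `nnx.Rngs` exposes named streams that are called to draw keys (the stream names and seeds below are illustrative):

```python
from flax.experimental import nnx

# In NNX, named RNG streams live on an nnx.Rngs object passed at construction time.
rngs = nnx.Rngs(params=0, dropout=1)

# Drawing a key from the 'dropout' stream is the rough analogue of
# Linen's `self.make_rng('dropout')` inside a module.
dropout_key = rngs.dropout()
```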
import jax.numpy as jnp

from flax.experimental import nnx


class TestMultiHeadAttention:
  def test_basic(self):
    module = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 7, 3)))
    assert y.shape == (1, 7, 6)
Can we port over the tests from tests/linen/linen_attention_test.py?
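As a starting point, a sketch of the kind of case that could be ported, assuming the positional `(inputs_q, inputs_k, inputs_v)` signature discussed above (the test name and shapes are illustrative, not taken from the Linen test file):

```python
import jax.numpy as jnp
from flax.experimental import nnx


class TestMultiHeadAttention:
  def test_cross_attention_shapes(self):
    # Hypothetical port: query and key/value sequences of different lengths.
    module = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
    q = jnp.ones((1, 7, 3))
    kv = jnp.ones((1, 5, 3))
    y = module(q, kv, kv)
    assert y.shape == (1, 7, 6)  # output follows the query length
```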
import jax.numpy as jnp

from flax.experimental import nnx


class TestLinearGeneral:
  def test_basic(self):
    module = nnx.LinearGeneral(2, 3, rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 2)))

    assert y.shape == (1, 3)
    assert module.kernel.shape == (2, 3)
    assert module.bias is not None
    assert module.bias.shape == (3,)

  def test_basic_multi_features(self):
    module = nnx.LinearGeneral(2, (3, 4), rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 2)))

    assert y.shape == (1, 3, 4)
    assert module.kernel.shape == (2, 3, 4)
    assert module.bias is not None
    assert module.bias.shape == (3, 4)
Can we port over the tests from tests/linen/linen_linear_test.py?
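For example, a sketch of one such case, assuming `use_bias` was carried over from Linen's `DenseGeneral` (the test name is illustrative):

```python
import jax.numpy as jnp
from flax.experimental import nnx


class TestLinearGeneral:
  def test_no_bias(self):
    # Hypothetical port of DenseGeneral's use_bias=False behaviour
    # (use_bias is assumed to exist on nnx.LinearGeneral).
    module = nnx.LinearGeneral(2, 3, use_bias=False, rngs=nnx.Rngs(0))
    y = module(jnp.ones((1, 2)))
    assert y.shape == (1, 3)
```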
# DeprecationWarning: pkg_resources is deprecated as an API.
"ignore:.*pkg_resources is deprecated as an API.*:DeprecationWarning",
# DeprecationWarning: Deprecated call to `pkg_resources.declare_namespace('google')`.
"ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning",
where are these warnings coming from?
Example usage::

  >>> import flax.linen as nn
This example seems to point to Linen rather than NNX?
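If the docstring is meant for the NNX module, the import and construction would presumably look more like the following. This is a sketch put together from the tests added in this PR; it is not clear from the excerpt which class's docstring the example belongs to.

```python
>>> from flax.experimental import nnx
>>> import jax.numpy as jnp

>>> layer = nnx.LinearGeneral(2, 3, rngs=nnx.Rngs(0))
>>> y = layer(jnp.ones((1, 2)))
>>> y.shape
(1, 3)
```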
What does this PR do?
- DenseGeneral as LinearGeneral
- MultiHeadDotProductAttention as MultiHeadAttention
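For reviewers skimming the thread, a minimal usage sketch of the two new modules, put together from the tests added in this PR:

```python
import jax.numpy as jnp
from flax.experimental import nnx

# LinearGeneral: a generalization of Linear, here with a multi-dimensional output.
linear = nnx.LinearGeneral(2, (3, 4), rngs=nnx.Rngs(0))
assert linear(jnp.ones((1, 2))).shape == (1, 3, 4)

# MultiHeadAttention: 2 heads, 3 input features, 6 qkv features.
attn = nnx.MultiHeadAttention(2, 3, 6, rngs=nnx.Rngs(0))
assert attn(jnp.ones((1, 7, 3))).shape == (1, 7, 6)
```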