added sow attention weights #3529
Conversation
👍

Just a small comment: can we give a more descriptive name to the `mdl` argument?
flax/linen/attention.py (outdated diff)

```diff
@@ -76,6 +77,10 @@ def dot_product_attention_weights(
     dtype: the dtype of the computation (default: infer from inputs and params)
     precision: numerical precision of the computation see `jax.lax.Precision`
       for details.
+    mdl: if not None, the attention weights are sowed into the 'intermediates'
```
maybe explain that this is the module used to sow the attention weights (if given).
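For context, here is a minimal sketch of how that argument is meant to be used. The keyword name `mdl` follows the diff above and may have been renamed in later revisions; everything else is standard Flax, and the example is illustrative rather than code from this PR.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyAttention(nn.Module):
    @nn.compact
    def __call__(self, q, k):
        # Passing the calling module lets the function sow the computed weights
        # into this module's 'intermediates' collection as 'attention_weights'.
        return nn.dot_product_attention_weights(q, k, mdl=self)


# query/key shaped [batch, length, num_heads, head_dim]
q = k = jnp.ones((1, 4, 2, 8))
m = TinyAttention()
variables = m.init(jax.random.key(0), q, k)

# Mark 'intermediates' as mutable so apply() returns the sowed collection.
_, state = m.apply(variables, q, k, mutable=['intermediates'])
print(state['intermediates']['attention_weights'][0].shape)  # (1, 2, 4, 4)
```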
Codecov Report

```
@@            Coverage Diff             @@
##             main    #3529      +/-   ##
==========================================
- Coverage   56.16%   56.04%   -0.13%
==========================================
  Files         100      100
  Lines       11861    11865       +4
==========================================
- Hits         6662     6650      -12
- Misses       5199     5215      +16
```

☔ View full report in Codecov by Sentry.
Hi @chiamp , may I know whether we could change to the alternative of returning the attention weights directly, e.g. as part of a tuple output?

I am asking because I just realized that sowed intermediates are stored in a nested dict keyed by module names. For example, if we have nested modules, retrieving the weights requires spelling out the full path of auto-generated submodule names. Here is a minimal example where my goal is to use the attention weights of specific attention layers inside a sub-module:

```python
import jax.numpy as jnp
from jax import random
from jax.nn import initializers
from flax import linen as nn


class SubModel(nn.Module):
    attention_kwargs: dict

    @nn.compact
    def __call__(self, x, return_weights=False):
        # Only the first and third attention layers are asked for their weights.
        x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
            x, return_weights=return_weights
        )
        x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x)
        x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
            x, return_weights=return_weights
        )
        return x


class Model(nn.Module):
    @nn.compact
    def __call__(self, x, return_weights=False):
        x = SubModel(
            dict(
                num_heads=8,
                qkv_features=16,
                kernel_init=initializers.ones,
                bias_init=initializers.zeros,
                deterministic=False,
            )
        )(x, return_weights=return_weights)
        return x


rng = random.key(0)
x = jnp.ones((4, 6, 5))

module = Model()
v = module.init(rng, x)
_, intermediates = module.apply(
    v, x, mutable=['intermediates'], return_weights=True
)
print(
    intermediates['intermediates']['SubModel_0'][
        'MultiHeadDotProductAttention_0'
    ]['attention_weights'][0].shape
)
```

However, if we could directly return a tuple, then any manipulation of the attention weights would be straightforward. I know that returning a tuple is quite easy to implement, but it is unfortunately a breaking change. I am wondering whether there are better solutions to this, or maybe I just misunderstand the intended usage.
hi @Xiaoming-Zhao, would accessing the sowed intermediates via
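As a side note on the discussion above, one way to avoid hard-coding the 'SubModel_0'/'MultiHeadDotProductAttention_0' path is to flatten the intermediates pytree and filter by leaf name. This is an editorial sketch (not code from the thread) that assumes the `intermediates` dict produced by the example above:

```python
from flax import traverse_util

# Flatten the nested intermediates dict into {(path segments): value} pairs,
# then keep every entry whose leaf name is 'attention_weights'.
flat = traverse_util.flatten_dict(intermediates['intermediates'])
attention_weights = {
    '/'.join(path): values[0]  # sow stores a tuple; take its first entry
    for path, values in flat.items()
    if path[-1] == 'attention_weights'
}
for name, w in attention_weights.items():
    print(name, w.shape)
```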
Resolves #3530 using @JyChang012's implementation of sowing attention weights; this behavior can be configured at call time. An alternative option would be to return the weights as a tuple, as PyTorch and TensorFlow do.
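For illustration, a minimal sketch of what the call-time configuration looks like. The flag name `sow_weights` reflects my reading of the API that came out of this work and may differ from earlier revisions of the PR (e.g. `return_weights` in the comment above); treat it as an assumption rather than the definitive signature.

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=16)
x = jnp.ones((4, 6, 5))
variables = attn.init(jax.random.key(0), x)

# Request the weights at call time and mark 'intermediates' as mutable so the
# sowed collection is returned alongside the layer output.
y, state = attn.apply(variables, x, sow_weights=True, mutable=['intermediates'])
print(state['intermediates']['attention_weights'][0].shape)  # (4, 2, 6, 6)
```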