added sow attention weights #3529

Merged 1 commit into google:main from the attention branch on Dec 6, 2023

Conversation

@chiamp (Collaborator) commented Dec 5, 2023

Resolves #3530 using @JyChang012's implementation of sowing attention weights; this behavior can be configured at call time. An alternative option would be to return the weights as a tuple, as PyTorch and TensorFlow do.
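
For reference, a minimal sketch of how the sowed weights can be pulled back out of a standalone attention module after this change. The call-time flag name (return_weights) follows the examples later in this thread and is an assumption, not the confirmed final API:

import jax
import jax.numpy as jnp
from flax import linen as nn

attn = nn.MultiHeadDotProductAttention(num_heads=2, qkv_features=8)
x = jnp.ones((1, 4, 8))
variables = attn.init(jax.random.key(0), x)

# 'intermediates' must be marked mutable at apply time, otherwise sow() is a no-op.
# return_weights is the flag name used in the examples later in this thread.
_, state = attn.apply(
  variables, x, return_weights=True, mutable=['intermediates']
)
weights = state['intermediates']['attention_weights'][0]  # (batch, num_heads, q_len, kv_len)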

@chiamp self-assigned this Dec 5, 2023
@chiamp (Collaborator, Author) commented Dec 5, 2023

@JyChang012

👍

@cgarciae (Collaborator) commented Dec 5, 2023

Just a small comment: can we give a more descriptive name to the mdl argument?

@chiamp (Collaborator, Author) commented Dec 5, 2023

Just a small comment: can we give a more descriptive name to the mdl argument?

Would module work instead?

@@ -76,6 +77,10 @@ def dot_product_attention_weights(
dtype: the dtype of the computation (default: infer from inputs and params)
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
mdl: if not None, the attention weights are sowed into the 'intermediates' collection.
Review comment from a Collaborator:

maybe explain that this is the module used to sow the attention weights (if given).
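
For context, roughly what the new argument enables inside dot_product_attention_weights; a sketch based on the docstring line above, not the exact diff:

# inside dot_product_attention_weights, once attn_weights has been computed
if mdl is not None:
  # record the weights under the name used elsewhere in this thread
  mdl.sow('intermediates', 'attention_weights', attn_weights)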

@codecov-commenter commented Dec 6, 2023

Codecov Report

Attention: 5 lines in your changes are missing coverage. Please review.

Comparison is base (512a6d8) 56.16% compared to head (9d989b0) 56.04%.
Report is 2 commits behind head on main.

Files                     Patch %   Lines
flax/linen/attention.py   0.00%     5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3529      +/-   ##
==========================================
- Coverage   56.16%   56.04%   -0.13%     
==========================================
  Files         100      100              
  Lines       11861    11865       +4     
==========================================
- Hits         6662     6650      -12     
- Misses       5199     5215      +16     

☔ View full report in Codecov by Sentry.

@copybara-service bot merged commit 50cd169 into google:main on Dec 6, 2023
19 checks passed
@chiamp deleted the attention branch on December 6, 2023, 21:46
@Xiaoming-Zhao commented Dec 8, 2023

Hi @chiamp, may I know whether we could change to the following alternative:

An alternative option would be to return the weights as a tuple, as PyTorch and TensorFlow do.

I am asking because I just realized that intermediates recorded via sow can only be extracted at the outermost call to an nn.Module.

For example, suppose we have nested modules, all defined with nn.compact, and nn.MultiHeadDotProductAttention is only called within some inner modules rather than the outermost one. Then it is not trivial to operate on the attention weights inside those inner modules during the forward pass.

Here is a minimal example where my goal is to use attention_weights inside SubModel's __call__:

import jax.numpy as jnp
from jax import random
from jax.nn import initializers
from flax import linen as nn


class SubModel(nn.Module):
  attention_kwargs: dict

  @nn.compact
  def __call__(self, x, return_weights=False):
    x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
      x, return_weights=return_weights
    )
    x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x)
    x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
      x, return_weights=return_weights
    )
    return x


class Model(nn.Module):

  @nn.compact
  def __call__(self, x, return_weights=False):
    x = SubModel(
      dict(
        num_heads=8,
        qkv_features=16,
        kernel_init=initializers.ones,
        bias_init=initializers.zeros,
        deterministic=False,
      )
    )(x, return_weights=return_weights)
    return x


rng = random.key(0)
x = jnp.ones((4, 6, 5))

module = Model()

v = module.init(rng, x)
_, intermediates = module.apply(
  v, x, mutable=['intermediates'], return_weights=True
)

print(intermediates['intermediates']['SubModel_0']['MultiHeadDotProductAttention_0']['attention_weights'][0].shape)

However, if we could directly return a tuple, any manipulation of the attention weights would be straightforward.
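
For illustration, the tuple-returning alternative would look roughly like this inside SubModel (hypothetical, not the merged behavior; return_weights here is illustrative):

# hypothetical API: the attention call returns (outputs, weights) directly,
# so the weights are usable right here in SubModel's forward pass
x, attention_weights = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
  x, return_weights=True
)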

I know that returning a tuple would be quite easy to implement, but it is unfortunately a breaking change. I am wondering whether there is a better solution to this.

Or maybe I just misunderstand sow, and the attention weights can actually be extracted within inner modules in a nested case. I am happy to learn more about it. Thanks a lot in advance.

@chiamp (Collaborator, Author) commented Dec 9, 2023

Hi @Xiaoming-Zhao, would accessing the sowed intermediates via self.variables work for you?

class SubModel(nn.Module):
  attention_kwargs: dict

  @nn.compact
  def __call__(self, x, return_weights=False):
    x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
      x, return_weights=return_weights
    )
    x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(x)
    x = nn.MultiHeadDotProductAttention(**self.attention_kwargs)(
      x, return_weights=return_weights
    )

    # access intermediates via self.variables
    if return_weights:
      attention_weights_0 = self.variables['intermediates']['MultiHeadDotProductAttention_0']['attention_weights']
      attention_weights_2 = self.variables['intermediates']['MultiHeadDotProductAttention_2']['attention_weights']

    return x


class Model(nn.Module):

  @nn.compact
  def __call__(self, x, return_weights=False):
    x = SubModel(
      dict(
        num_heads=8,
        qkv_features=16,
        kernel_init=initializers.ones,
        bias_init=initializers.zeros,
        deterministic=False,
      )
    )(x, return_weights=return_weights)

    # access intermediates via self.variables
    if return_weights:
      submodel_attention_weights_0 = self.variables['intermediates']['SubModel_0']['MultiHeadDotProductAttention_0']['attention_weights']
      submodel_attention_weights_2 = self.variables['intermediates']['SubModel_0']['MultiHeadDotProductAttention_2']['attention_weights']

    return x
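
For completeness, a usage sketch with the same v and x as in your earlier example: 'intermediates' still has to be marked mutable at apply time, otherwise sow() is a no-op and the self.variables lookups above would find nothing:

# 'intermediates' must be mutable so the inner sow() calls actually record the weights
_, state = Model().apply(
  v, x, return_weights=True, mutable=['intermediates']
)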


Successfully merging this pull request may close these issues:

Feature request: optionally sow attention weight in dot_product_attention (#3530)
5 participants