This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

sliding window self-attention cell #1395

Open · wants to merge 2 commits into master

Conversation

ZiyueHuang
Member

Description

Adds an AttentionCell for sliding-window self-attention, including support for per-head dilation and the causal attention mode, as described in Longformer: The Long-Document Transformer.

cc @sxjscience @szhengac
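
A minimal usage sketch of the proposed cell (argument order, shapes, and dtypes follow the benchmark script posted in the comments below; the final API may still change):

import numpy as np
import mxnet as mx
from gluonnlp.attention_cell import MultiHeadSlidingWindowAttentionCell

batch_size, seq_length, num_heads, num_head_units = 1, 1024, 12, 64
w, symmetric = 128, True        # one-sided window size; attend to both sides when symmetric=True
ctx = mx.gpu(0)                 # the benchmark exercises the GPU kernels

cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
shape = (batch_size, seq_length, num_heads, num_head_units)
query = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
key = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
value = mx.np.array(np.random.normal(0, 1, shape), ctx=ctx, dtype=np.float32)
dilation = mx.np.array(np.ones((num_heads,)), ctx=ctx, dtype=np.int32)            # per-head dilation
valid_length = mx.np.array(np.full((batch_size,), seq_length), ctx=ctx, dtype=np.int32)
out, _ = cell(query, key, value, dilation, valid_length)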

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

cc @dmlc/gluon-nlp-team

ZiyueHuang requested a review from a team as a code owner on October 20, 2020 12:20.
@ZiyueHuang
Member Author

Waiting for apache/mxnet#19387 to be merged.

@sxjscience
Member

Is it possible for us to revise the interface to be similar to https://www.deepspeed.ai/tutorials/sparse-attention/?
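
For reference, a hypothetical config-driven wrapper in that spirit could look roughly like this (SlidingWindowConfig and build_attention_cell are illustrative names only, not DeepSpeed's or GluonNLP's actual API):

from dataclasses import dataclass
from gluonnlp.attention_cell import MultiHeadSlidingWindowAttentionCell

@dataclass
class SlidingWindowConfig:
    # hypothetical config object, mirroring DeepSpeed's config-driven sparse-attention style
    window_size: int = 128      # one-sided window size w
    symmetric: bool = True      # attend to both sides of each position
    # per-head dilation would still be passed at call time, as in this PR

def build_attention_cell(cfg: SlidingWindowConfig):
    # hypothetical factory that constructs the attention cell from the config
    return MultiHeadSlidingWindowAttentionCell(cfg.window_size, cfg.symmetric)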

@ZiyueHuang
Member Author

Benchmark script:



import time

import numpy as np
import mxnet as mx
from gluonnlp.attention_cell import MultiHeadAttentionCell, MultiHeadSlidingWindowAttentionCell

def test_multi_head_sliding_window_dot_attention_cell():

    def gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d):
        """Generate sliding_window attention mask for the full attention matrix ( seq_len^2 ).
        """
        mask_np = np.zeros((batch_size, seq_length, seq_length))
        for i in range(seq_length):
            end = (i + 1 + w * d) if symmetric else (i + 1)
            for j in range(i - w * d, end, d):
                if j >= 0 and j < seq_length:
                    mask_np[:, i, j] = 1
        return mask_np
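
    # Illustrative example (not executed): for seq_length=5, w=1, d=1, symmetric=True the mask
    # produced above is a tridiagonal band, i.e. each query attends to itself and its immediate
    # left/right neighbours:
    #   [[1, 1, 0, 0, 0],
    #    [1, 1, 1, 0, 0],
    #    [0, 1, 1, 1, 0],
    #    [0, 0, 1, 1, 1],
    #    [0, 0, 0, 1, 1]]
    # With symmetric=False only the lower (causal) half of the band is kept.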


    def test_selfatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        # Baseline: dense multi-head attention with an explicit (seq_length x seq_length) sliding-window mask.
        attn_cell = MultiHeadAttentionCell()
        # Generate the data
        ctx = mx.gpu(0)
        #ctx = mx.cpu()
        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        mask = gen_sliding_window_mask_full(batch_size, seq_length, w, symmetric, d)
        mask = mx.np.array(mask, ctx=ctx, dtype=np.float32)

        query = mx.np.array(query, ctx=ctx, dtype=np.float32)
        key = mx.np.array(key, ctx=ctx, dtype=np.float32)
        value = mx.np.array(value, ctx=ctx, dtype=np.float32)

        query.attach_grad()
        key.attach_grad()
        value.attach_grad()

        mx.npx.waitall()
        tic = time.time()

        with mx.autograd.record():
            out, _ = attn_cell(query, key, value, mask)
            out.backward()

        mx.npx.waitall()
        toc = time.time()

        return (toc - tic)


    def test_swatten(batch_size, seq_length, num_heads, num_head_units, w, symmetric, d):
        # Proposed cell: uses the dedicated sliding-window attention kernels instead of a dense mask.
        sw_attn_cell = MultiHeadSlidingWindowAttentionCell(w, symmetric)
        # Generate the data
        ctx = mx.gpu(0)
        #ctx = mx.cpu()
        query = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        key = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))
        value = np.random.normal(0, 1, (batch_size, seq_length, num_heads, num_head_units))

        query = mx.np.array(query, ctx=ctx, dtype=np.float32)
        key = mx.np.array(key, ctx=ctx, dtype=np.float32)
        value = mx.np.array(value, ctx=ctx, dtype=np.float32)

        query.attach_grad()
        key.attach_grad()
        value.attach_grad()

        # Per-head dilation; every head uses the same dilation d in this benchmark.
        dilation = mx.np.array(np.full((num_heads,), d), ctx=ctx, dtype=np.int32)
        # Every sequence in the batch is treated as full length.
        valid_length = mx.np.array(np.full((batch_size,), seq_length), ctx=ctx, dtype=np.int32)

        mx.npx.waitall()
        tic = time.time()

        with mx.autograd.record():
            sw_out, _ = sw_attn_cell(query, key, value, dilation, valid_length)
            sw_out.backward()

        mx.npx.waitall()
        toc = time.time()

        return (toc - tic)

    num_repeat = 5
    num_warmup = 2  # the first runs include kernel/context warm-up and are excluded from the timing

    for seq_length in [512, 1024, 2048, 4096]:
        w = seq_length // 8
        dur = 0.
        for i in range(num_repeat):
            tmp_dur = test_selfatten(1, seq_length, 12, 64, w, True, 1)
            if i >= num_warmup:
                dur += tmp_dur
        dur /= (num_repeat - num_warmup)
        print('full-attention seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))

        dur = 0.
        for i in range(num_repeat):
            tmp_dur = test_swatten(1, seq_length, 12, 64, w, True, 1)
            if i >= num_warmup:
                dur += tmp_dur
        dur /= (num_repeat - num_warmup)
        print('sliding-window-attention seq_length={}, w={}, time={:.3f}'.format(seq_length, w, dur))


test_multi_head_sliding_window_dot_attention_cell()
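
For scale: with seq_length = 4096 and w = 4096 // 8 = 512, the dense baseline computes 4096^2 ≈ 16.8M attention scores per head, while the symmetric sliding window touches only about 4096 * (2*512 + 1) ≈ 4.2M, roughly a quarter, and never materializes the full mask. With a fixed window size (as in Longformer) the sliding-window cell should scale linearly rather than quadratically in sequence length, so the timing gap should widen further.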

@sxjscience
Member

Is there any update on this PR?

@szhengac
Member

szhengac commented Dec 2, 2020

@sxjscience it seems the error AttributeError: module 'mxnet.ndarray.numpy_extension' has no attribute 'sldwin_atten_score' is because the MXNet version is not the latest.

@sxjscience
Member

Yes, we can merge master to retrigger the tests.

@sxjscience
Member

Do we have any update on this? @ZiyueHuang, would you have time to rebase the code?
