[transformer] support flash att by 'torch scaled dot attention' (#2351)
* [transformer] support flash att by 'torch scaled dot attention'

* pass ut on cpu

* pass ut on cpu

* pass ut on cpu

* zero out pad mask att

* support attention mask bias in encoder

* fix jit and unit test

* sdpa in decoder and search
Mddct authored Feb 21, 2024
1 parent 87831da commit 935250b
Showing 8 changed files with 195 additions and 63 deletions.
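
For context: with use_sdpa enabled, MultiHeadedAttention delegates to torch.nn.functional.scaled_dot_product_attention, and the boolean attention mask becomes an additive float bias. A minimal sketch of the equivalence the new unit test exercises; shapes and tensor names below are illustrative assumptions, not taken from the repository:

import math
import torch

q = torch.rand(2, 4, 100, 64)                    # (batch, head, time, d_k)
k, v = torch.rand_like(q), torch.rand_like(q)
lengths = torch.tensor([100, 60])
bool_mask = torch.arange(100)[None, :] < lengths[:, None]   # True = keep
bool_mask = bool_mask[:, None, None, :]                     # (batch, 1, 1, time)

# Manual path: scale, mask with a large negative value, softmax, weight values.
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(~bool_mask, torch.finfo(torch.float).min)
classic = torch.softmax(scores, dim=-1) @ v

# SDPA path: the same mask expressed as an additive bias (0.0 = keep, very negative = drop).
bias = (1.0 - bool_mask.float()) * torch.finfo(torch.float).min
sdpa = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=bias, dropout_p=0.0, scale=1 / math.sqrt(q.size(-1)))

assert torch.allclose(classic, sdpa, atol=1e-6)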
4 changes: 2 additions & 2 deletions requirements.txt
@@ -13,8 +13,8 @@ flake8-pyi==20.5.0
mccabe
pycodestyle==2.6.0
pyflakes==2.2.0
-torch==2.1.2
-torchaudio==2.1.2
+torch>=2.1.2
+torchaudio>=2.1.2
tqdm
deepspeed<0.13.0
librosa
50 changes: 0 additions & 50 deletions test/wenet/ssl/w2vbert/test_w2vbert.py

This file was deleted.

111 changes: 111 additions & 0 deletions test/wenet/transformer/test_attention.py
@@ -0,0 +1,111 @@
import torch
import pytest
from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.utils.class_utils import WENET_ACTIVATION_CLASSES

from wenet.utils.mask import add_optional_chunk_mask, make_non_pad_mask


@pytest.mark.parametrize("args", [
    {
        "n_feat": 256,
        "n_head": 4,
        "dropout_rate": 0.0
    },
    {
        "n_feat": 512,
        "n_head": 8,
        "dropout_rate": 0.0
    },
    {
        "n_feat": 1280,
        "n_head": 20,
        "dropout_rate": 0.0
    },
    {
        "n_feat": 512,
        "n_head": 4,
        "dropout_rate": 0.0
    },
])
def test_sdpa(args):
    torch.manual_seed(777)
    mha_module = MultiHeadedAttention(use_sdpa=False, **args)
    torch.manual_seed(777)
    mha_module_with_sdpa = MultiHeadedAttention(use_sdpa=True, **args)
    mha_module.eval()
    mha_module_with_sdpa.eval()

    q = torch.rand(10, 100, args['n_feat'])
    k = torch.rand(10, 100, args['n_feat'])
    v = torch.rand(10, 100, args['n_feat'])
    input_lens = torch.tensor([100, 90, 80, 79, 60, 51, 40, 30, 10, 5])
    mask = make_non_pad_mask(input_lens).unsqueeze(1)
    att_mask = add_optional_chunk_mask(q,
                                       mask,
                                       use_dynamic_chunk=True,
                                       decoding_chunk_size=0,
                                       static_chunk_size=0,
                                       use_dynamic_left_chunk=True,
                                       num_decoding_left_chunks=-1)
    output, cache = mha_module(q, k, v, mask=att_mask)

    att_mask_bias = (1.0 - att_mask.float()) * torch.finfo(torch.float).min
    output_with_sdpa, cache_with_sdpa = mha_module_with_sdpa(
        q, k, v, mask=att_mask_bias)
    assert torch.allclose(
        output * mask.transpose(1, 2),
        output_with_sdpa * mask.transpose(1, 2),
        atol=9e-7,
    )
    assert torch.allclose(cache, cache_with_sdpa)

    n_blocks = 12
    torch.manual_seed(777)
    mha_layers = [
        TransformerEncoderLayer(
            args['n_feat'],
            MultiHeadedAttention(use_sdpa=False, **args),
            PositionwiseFeedForward(
                args['n_feat'],
                2048,
                0.0,
                WENET_ACTIVATION_CLASSES['swish'](),
            ),
            0.0,
            normalize_before=True,
        ) for _ in range(n_blocks)
    ]

    torch.manual_seed(777)
    mha_layers_with_sdpa = [
        TransformerEncoderLayer(
            args['n_feat'],
            MultiHeadedAttention(use_sdpa=True, **args),
            PositionwiseFeedForward(
                args['n_feat'],
                2048,
                0.0,
                WENET_ACTIVATION_CLASSES['swish'](),
            ),
            0.0,
            normalize_before=True,
        ) for _ in range(n_blocks)
    ]

    for i in range(n_blocks):
        output, _, cache, _ = mha_layers[i](q, att_mask, None, mask)
        output_with_sdpa, _, cache_with_sdpa, _ = mha_layers_with_sdpa[i](
            q, att_mask_bias, None, mask)

        assert torch.allclose(
            output * mask.transpose(1, 2),
            output_with_sdpa * mask.transpose(1, 2),
            atol=9e-7,
            rtol=9e-4,
        )
        # assert torch.allclose(cache, cache_with_sdpa)

        q = output
24 changes: 21 additions & 3 deletions wenet/transformer/attention.py
@@ -36,7 +36,8 @@ def __init__(self,
                 n_head: int,
                 n_feat: int,
                 dropout_rate: float,
-                 key_bias: bool = True):
+                 key_bias: bool = True,
+                 use_sdpa: bool = False):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
@@ -49,6 +50,9 @@ def __init__(self,
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

+        self.use_sdpa = use_sdpa
+        self.dropout_rate = dropout_rate
+
    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -192,8 +196,22 @@ def forward(
        # non-trivial to calculate `next_cache_start` here.
        new_cache = torch.cat((k, v), dim=-1)

-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask), new_cache
+        if not self.use_sdpa:
+            scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+            return self.forward_attention(v, scores, mask), new_cache
+        else:
+            output = torch.nn.functional.scaled_dot_product_attention(
+                q,
+                k,
+                v,
+                attn_mask=mask.unsqueeze(1),
+                dropout_p=self.dropout_rate,
+                scale=1 / math.sqrt(self.d_k),
+            )
+            output = (output.transpose(1, 2).contiguous().view(
+                query.size(0), -1,
+                self.h * self.d_k))  # (batch, time1, d_model)
+            return self.linear_out(output), new_cache


class RelPositionMultiHeadedAttention(MultiHeadedAttention):
18 changes: 14 additions & 4 deletions wenet/transformer/decoder.py
@@ -26,6 +26,7 @@
    WENET_ATTENTION_CLASSES,
    WENET_ACTIVATION_CLASSES,
)
+from wenet.utils.common import mask_to_bias
from wenet.utils.mask import (subsequent_mask, make_pad_mask)


@@ -73,6 +74,7 @@ def __init__(
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
+        use_sdpa: bool = False,
    ):
        super().__init__()
        attention_dim = encoder_output_size
@@ -98,10 +100,10 @@ def __init__(
                attention_dim,
                WENET_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim,
-                    self_attention_dropout_rate, key_bias),
+                    self_attention_dropout_rate, key_bias, use_sdpa),
                WENET_ATTENTION_CLASSES["selfattn"](
                    attention_heads, attention_dim, src_attention_dropout_rate,
-                    key_bias) if src_attention else None,
+                    key_bias, use_sdpa) if src_attention else None,
                PositionwiseFeedForward(attention_dim, linear_units,
                                        dropout_rate, activation),
                dropout_rate,
@@ -111,6 +113,7 @@

        self.gradient_checkpointing = gradient_checkpointing
        self.tie_word_embedding = tie_word_embedding
+        self.use_sdpa = use_sdpa

    def forward(
        self,
@@ -152,6 +155,10 @@ def forward(
                            device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
+        if self.use_sdpa:
+            tgt_mask = mask_to_bias(tgt_mask, tgt.dtype)
+            memory_mask = mask_to_bias(memory_mask, memory_mask.dtype)
+
        x, _ = self.embed(tgt)
        if self.gradient_checkpointing and self.training:
            x = self.forward_layers_checkpointed(x, tgt_mask, memory,
@@ -290,6 +297,7 @@ def __init__(
        key_bias: bool = True,
        gradient_checkpointing: bool = False,
        tie_word_embedding: bool = False,
+        use_sdpa: bool = False,
    ):

        super().__init__()
@@ -309,7 +317,8 @@ def __init__(
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
-            tie_word_embedding=tie_word_embedding)
+            tie_word_embedding=tie_word_embedding,
+            use_sdpa=use_sdpa)

        self.right_decoder = TransformerDecoder(
            vocab_size,
@@ -326,7 +335,8 @@ def __init__(
            normalize_before,
            key_bias=key_bias,
            gradient_checkpointing=gradient_checkpointing,
-            tie_word_embedding=tie_word_embedding)
+            tie_word_embedding=tie_word_embedding,
+            use_sdpa=use_sdpa)

    def forward(
        self,
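
The helper mask_to_bias imported from wenet.utils.common above is part of this commit, but its file diff is not shown on this page. Judging from how the new unit test builds att_mask_bias, an equivalent sketch (an assumption, not necessarily the committed implementation) is:

import torch

def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Boolean attention mask (True = keep) -> additive bias for SDPA:
    # 0.0 where attention is allowed, a very large negative value elsewhere.
    mask = mask.to(dtype)
    return (1.0 - mask) * torch.finfo(dtype).min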
14 changes: 11 additions & 3 deletions wenet/transformer/encoder.py
@@ -31,6 +31,7 @@
)
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask
+from wenet.utils.common import mask_to_bias


class BaseEncoder(torch.nn.Module):
@@ -53,6 +54,7 @@ def __init__(
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        gradient_checkpointing: bool = False,
+        use_sdpa: bool = False,
    ):
        """
        Args:
@@ -84,6 +86,7 @@ def __init__(
            key_bias: whether use bias in attention.linear_k, False for whisper models.
            gradient_checkpointing: rerunning a forward-pass segment for each
                checkpointed segment during backward.
+            use_sdpa: whether to use SDPA, currently only support transformer for now
        """
        super().__init__()
        self._output_size = output_size
@@ -103,6 +106,7 @@ def __init__(
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk
        self.gradient_checkpointing = gradient_checkpointing
+        self.use_sdpa = use_sdpa

    def output_size(self) -> int:
        return self._output_size
@@ -149,6 +153,8 @@ def forward(
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
+        if self.use_sdpa:
+            chunk_masks = mask_to_bias(chunk_masks, xs.dtype)
        if self.gradient_checkpointing and self.training:
            xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
                                                  mask_pad)
@@ -355,6 +361,7 @@ def __init__(
        key_bias: bool = True,
        activation_type: str = "relu",
        gradient_checkpointing: bool = False,
+        use_sdpa: bool = False,
    ):
        """ Construct TransformerEncoder
@@ -365,15 +372,16 @@
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
-                         use_dynamic_left_chunk, gradient_checkpointing)
+                         use_dynamic_left_chunk, gradient_checkpointing,
+                         use_sdpa)
        activation = WENET_ACTIVATION_CLASSES[activation_type]()
        self.encoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                WENET_ATTENTION_CLASSES["selfattn"](attention_heads,
                                                    output_size,
                                                    attention_dropout_rate,
-                                                    key_bias),
+                                                    key_bias, use_sdpa),
                PositionwiseFeedForward(output_size, linear_units,
                                        dropout_rate, activation),
                dropout_rate, normalize_before) for _ in range(num_blocks)
@@ -433,7 +441,7 @@ def __init__(
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         static_chunk_size, use_dynamic_chunk, global_cmvn,
-                         use_dynamic_left_chunk, gradient_checkpointing)
+                         use_dynamic_left_chunk, gradient_checkpointing, False)
        activation = WENET_ACTIVATION_CLASSES[activation_type]()

        # self-attention module definition
7 changes: 6 additions & 1 deletion wenet/transformer/search.py
@@ -19,7 +19,8 @@
import torch
from torch.nn.utils.rnn import pad_sequence

-from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens)
+from wenet.utils.common import (add_sos_eos, log_add, add_whisper_tokens,
+                                mask_to_bias)
from wenet.utils.ctc_utils import remove_duplicates_and_blank
from wenet.utils.mask import (make_pad_mask, mask_finished_preds,
                              mask_finished_scores, subsequent_mask)
@@ -289,6 +290,8 @@ def attention_beam_search(
    ]).unsqueeze(1).to(device)  # (B*N, 1)
    end_flag = torch.zeros_like(scores, dtype=torch.bool, device=device)
    cache: Optional[List[torch.Tensor]] = None
+    if model.decoder.use_sdpa:
+        encoder_mask = mask_to_bias(encoder_mask, encoder_out.dtype)
    # 2. Decoder forward step by step
    for i in range(prefix_len, maxlen + 1):
        # Stop if all batch and all beam produce eos
@@ -297,6 +300,8 @@
        # 2.1 Forward decoder step
        hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
            running_size, 1, 1).to(device)  # (B*N, i, i)
+        if model.decoder.use_sdpa:
+            hyps_mask = mask_to_bias(hyps_mask, encoder_out.dtype)
        # logp: (B*N, vocab)
        logp, cache = model.decoder.forward_one_step(encoder_out, encoder_mask,
                                                     hyps, hyps_mask, cache)
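
A hedged usage sketch for the new flag, assuming the TransformerEncoder constructor and forward signatures outside this diff (input_size first, default output_size of 256, forward(xs, xs_lens)) are unchanged by this commit:

import torch
from wenet.transformer.encoder import TransformerEncoder

# use_sdpa=True routes self-attention through
# torch.nn.functional.scaled_dot_product_attention; False keeps the old path.
encoder = TransformerEncoder(input_size=80, use_sdpa=True)
encoder.eval()

feats = torch.rand(2, 100, 80)          # (batch, frames, mel bins), made-up input
feats_lens = torch.tensor([100, 80])
with torch.no_grad():
    out, out_mask = encoder(feats, feats_lens)
# out: (batch, subsampled frames, output_size); out_mask marks the valid frames.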