fix-causal-fa-infer #7200

Merged 2 commits on Aug 10, 2023
34 changes: 23 additions & 11 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -263,26 +263,29 @@ def _checkpointed_attention_forward(
         rotary_pos_emb=None,
         relative_position_bias=None,
         headscale_tensor=None,
+        inference_mode=None,
     ):
         """Forward method with activation checkpointing."""

         def custom_forward(*inputs):
-            if len(inputs) == 7:
+            if len(inputs) == 8:
                 query_layer = inputs[0]
                 key_layer = inputs[1]
                 value_layer = inputs[2]
                 attention_mask = inputs[3]
                 rotary_pos_emb = inputs[4]
                 relative_position_bias = inputs[5]
                 headscale_tensor = inputs[6]
-            elif len(inputs) == 8:
+                inference_mode = inputs[7]
+            elif len(inputs) == 9:
                 query_layer = inputs[0]
                 key_layer = inputs[1]
                 value_layer = inputs[2]
                 attention_mask = inputs[3]
                 rotary_pos_emb = (inputs[4], inputs[5])
                 relative_position_bias = inputs[6]
                 headscale_tensor = inputs[7]
+                inference_mode = inputs[8]
             else:
                 raise ValueError('unexpected number of inputs')
             output_ = self.core_attention(
@@ -293,6 +296,7 @@ def custom_forward(*inputs):
                 rotary_pos_emb=rotary_pos_emb,
                 relative_position_bias=relative_position_bias,
                 headscale_tensor=headscale_tensor,
+                inference_mode=inference_mode,
             )
             return output_

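Because the checkpointed path re-invokes custom_forward with positional inputs only, threading inference_mode through bumps both accepted tuple lengths by one (7 becomes 8, and 8 becomes 9 for the split rotary-embedding case). A minimal sketch of the same pattern with torch.utils.checkpoint, with toy shapes and a toy attention body rather than the NeMo implementation:

import torch
from torch.utils.checkpoint import checkpoint

def custom_forward(*inputs):
    # Same pattern as in the diff: branch on tuple length, because checkpoint()
    # re-invokes this function with positional inputs only.
    if len(inputs) == 3:                # q, k, inference_mode
        q, k, inference_mode = inputs
    elif len(inputs) == 4:              # q, k, rotary stand-in, inference_mode
        q, k, _rotary, inference_mode = inputs
    else:
        raise ValueError('unexpected number of inputs')
    scores = q @ k.transpose(-1, -2)
    if not inference_mode:              # keep causal masking outside cached decode
        mask = torch.triu(torch.ones_like(scores), diagonal=1).bool()
        scores = scores.masked_fill(mask, float('-inf'))
    return scores.softmax(dim=-1)

q = torch.randn(4, 8, requires_grad=True)
k = torch.randn(4, 8, requires_grad=True)
out = checkpoint(custom_forward, q, k, False, use_reentrant=False)
out.sum().backward()
print(q.grad.shape)  # torch.Size([4, 8])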
@@ -523,6 +527,7 @@ def forward(
                 rotary_pos_emb=rotary_pos_emb,
                 relative_position_bias=relative_position_bias,
                 headscale_tensor=self.head_scale_tensor if self.headscale else None,
+                inference_mode=inference_max_sequence_len is not None and query_layer.shape[0] == 1,
             )
         else:
             context_layer = self.core_attention(
@@ -535,6 +540,7 @@ def forward(
                 rotary_pos_emb=rotary_pos_emb,
                 relative_position_bias=relative_position_bias,
                 headscale_tensor=self.head_scale_tensor if self.headscale else None,
+                inference_mode=inference_max_sequence_len is not None and query_layer.shape[0] == 1,
             )

         # =================
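Both call sites derive the flag the same way: a KV-cache inference buffer must be active (inference_max_sequence_len is not None) and the current query must hold exactly one token (query_layer.shape[0] == 1 in the sequence-first [sq, b, np, hn] layout), i.e. an incremental decode step rather than the initial prompt pass. A toy generation loop with made-up shapes, just to show when the flag flips:

import torch

inference_max_sequence_len = 16               # KV-cache capacity; only set at inference time
prompt_len, new_tokens = 5, 3

for sq in [prompt_len] + [1] * new_tokens:    # full prompt first, then one token per step
    query_layer = torch.empty(sq, 1, 8, 64)   # [sq, b, np, hn]
    inference_mode = inference_max_sequence_len is not None and query_layer.shape[0] == 1
    print(f'sq={sq}  inference_mode={inference_mode}')
# sq=5  inference_mode=False  -> causal masking stays on for the prompt pass
# sq=1  inference_mode=True   -> causal masking is dropped for cached decode steps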
@@ -821,6 +827,7 @@ def forward(
         rotary_pos_emb=None,
         relative_position_bias=None,
         headscale_tensor=None,
+        inference_mode=None,
     ):
         b, np, sq, sk, hn = (
             query_layer.size(1),
@@ -878,7 +885,9 @@ def forward(
         # relative_position_bias [b, np, sq, sk]
         # context_layer [b, np, sq, hn]
         # ==================================================
-        context_layer = self.attn_fn(query_layer, key_layer, value_layer, attention_mask, relative_position_bias)
+        context_layer = self.attn_fn(
+            query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode
+        )

         if headscale_tensor is not None:
             context_layer = context_layer * headscale_tensor
@@ -892,7 +901,7 @@ def forward(

         return context_layer

-    def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias):
+    def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode):
         sq, b, np, hn = query_layer.shape
         sk = key_layer.shape[0]

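torch_attention accepts the new argument but never reads it; its dense attention_mask already carries any causal structure, so the parameter exists only to keep one positional signature for the self.attn_fn dispatch. A rough sketch of that shape of API (illustrative only, not the NeMo classes):

import torch

def torch_attention(q, k, v, mask, bias, inference_mode):    # flag accepted, unused
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    if bias is not None:
        scores = scores + bias
    if mask is not None:
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
    return scores.softmax(dim=-1) @ v

def flash_attention(q, k, v, mask, bias, inference_mode):    # flag would decide is_causal here
    # stand-in for the fused path; same contract as torch_attention
    return torch_attention(q, k, v, mask, bias, inference_mode)

attn_fn = flash_attention                     # picked once, like self.attn_fn
q = k = v = torch.randn(2, 6, 64)             # toy [b, s, hn] layout
print(attn_fn(q, k, v, None, None, False).shape)  # torch.Size([2, 6, 64])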
@@ -948,7 +957,7 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias):

         return context_layer

-    def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias):
+    def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode):
         query_layer = rearrange(query_layer, 'sq b np hn -> b sq np hn')
         key_layer = rearrange(key_layer, 'sk b np hn -> b sk np hn')
         value_layer = rearrange(value_layer, 'sv b np hn -> b sv np hn')
@@ -960,12 +969,16 @@ def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias):
         attention_mask = _cast_if_autocast_enabled(attention_mask)
         attention_bias = _cast_if_autocast_enabled(attention_bias)

+        is_causal = self.attn_mask_type == AttnMaskType.causal and not inference_mode
+
         if attention_bias is not None:
-            return self.flash_attention_triton(query_layer, key_layer, value_layer, attention_mask, attention_bias,)
+            return self.flash_attention_triton(
+                query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal,
+            )
         else:
-            return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask,)
+            return self.flash_attention_cuda(query_layer, key_layer, value_layer, attention_mask, is_causal,)

-    def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask):
+    def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask, is_causal):
         batch_size, seqlen, nheads, _ = query_layer.shape

         # True: attend / False: not attend
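This is the heart of the fix: is_causal is now decided once, from the configured mask type plus the new flag, instead of the per-kernel sequence-length comparisons removed below. During cached decoding the query holds a single token while the keys/values span the whole cache, and a causal mask anchored at the top-left (assuming the alignment of the flash-attn kernels this code path targets) would let that lone query attend only to the first cached position. A plain-PyTorch toy, illustrative only, showing the difference:

import torch

def toy_attention(q, k, v, causal):
    # q: [sq, hn], k/v: [sk, hn]; top-left-aligned causal mask when requested
    scores = q @ k.t() / q.shape[-1] ** 0.5
    if causal:
        sq, sk = scores.shape
        keep = torch.ones(sq, sk).tril().bool()
        scores = scores.masked_fill(~keep, float('-inf'))
    return scores.softmax(dim=-1) @ v

torch.manual_seed(0)
k = torch.randn(8, 4)      # 8 cached key/value positions
v = torch.randn(8, 4)
q = torch.randn(1, 4)      # single decode-step query

print(toy_attention(q, k, v, causal=True))   # sees only cache position 0 under top-left alignment
print(toy_attention(q, k, v, causal=False))  # attends over the whole cache -- what inference_mode selects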
@@ -984,7 +997,7 @@ def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask):
         q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_layer, attention_mask_q)
         k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_layer, attention_mask_kv)
         v, _, _, _ = unpad_input(value_layer, attention_mask_kv)
-        is_causal = self.attn_mask_type == AttnMaskType.causal and query_layer.shape[1] == key_layer.shape[1]
+
         context_layer = flash_attn_unpadded_func(
             q,
             k,
@@ -1004,7 +1017,7 @@ def flash_attention_cuda(self, query_layer, key_layer, value_layer, attention_mask):
         context_layer = context_layer.permute(0, 2, 1, 3)
         return context_layer

-    def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_mask, attention_bias):
+    def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, is_causal):
         if self.attention_dropout_p > 0.0:
             raise NotImplementedError(f'attention_dropout not implemented for flash_attention with attention bias')

@@ -1024,7 +1037,6 @@ def flash_attention_triton(self, query_layer, key_layer, value_layer, attention_mask, attention_bias):
             if attention_bias.shape[3] == attention_mask_kv.shape[3]:
                 attention_bias = attention_bias.masked_fill(~attention_mask_kv, torch.finfo(query_layer.dtype).min)

-        is_causal = self.attn_mask_type == AttnMaskType.causal and query_layer.shape[1] == key_layer.shape[1]
         context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_bias, is_causal,)

         # [b, sq, np, hn] -> [b, np, sq, hn]