
Support flash decoding #7744

Merged: 11 commits, Oct 24, 2023
1 change: 1 addition & 0 deletions examples/nlp/language_modeling/megatron_gpt_eval.py
@@ -219,6 +219,7 @@ def main(cfg) -> None:
pretrained_cfg.activations_checkpoint_granularity = None
pretrained_cfg.activations_checkpoint_method = None
pretrained_cfg.precision = trainer.precision
pretrained_cfg["use_flash_attention"] = cfg.inference.get("use_flash_attention", False)
if pretrained_cfg.get('mcore_gpt', False):
# with dist checkpointing we can use the model parallel config specified by the user
pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size
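The eval-script change above copies the `use_flash_attention` flag from the inference config into the pretrained model config, defaulting to `False` when the key is absent. A minimal sketch of that pattern with OmegaConf (the config values here are made up for illustration, not taken from the PR):

```python
# Hedged sketch: how the inference flag reaches the pretrained config.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"inference": {"use_flash_attention": True}})
pretrained_cfg = OmegaConf.create({})

# Mirrors the diff above: default to False when the key is absent.
pretrained_cfg["use_flash_attention"] = cfg.inference.get("use_flash_attention", False)
print(pretrained_cfg.use_flash_attention)  # True
```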
27 changes: 26 additions & 1 deletion nemo/collections/nlp/modules/common/megatron/attention.py
@@ -89,6 +89,14 @@
flash_attn_unpadded_func, flash_attn_func_triton, flash_attn_func = None, None, None
unpad_input, pad_input = None, None

try:
# Flash Attention 2.2
from flash_attn import flash_attn_with_kvcache

except (ImportError, ModuleNotFoundError):

flash_attn_with_kvcache = None

""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
@@ -144,6 +152,7 @@ def __init__(
self.normalize_attention_scores = normalize_attention_scores
self.position_embedding_type = position_embedding_type
self.multi_query_attention = multi_query_attention
self.use_flash_attention = use_flash_attention

self.megatron_legacy = megatron_legacy
self.dtype = utils_funcs.torch_dtype_from_precision(precision, megatron_amp_O2)
@@ -503,7 +512,23 @@ def forward(
if get_key_value:
present = (key_layer, value_layer)

if checkpoint_core_attention:
if (
flash_attn_with_kvcache is not None
and self.use_flash_attention
and rotary_pos_emb is not None
and inference_max_sequence_len
and not set_inference_key_value_memory
):
# Mainly used for decoding with sq=1
q = rearrange(apply_rotary_pos_emb(query_layer, rotary_pos_emb[0]), 'sq b np hn -> b sq np hn')
k = rearrange(apply_rotary_pos_emb(key_layer, rotary_pos_emb[1]), 'sk b np hn -> b sk np hn')
v = rearrange(value_layer, 'sk b np hn -> b sk np hn')
context_layer = flash_attn_with_kvcache(
q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
)
context_layer = rearrange(context_layer, 'b sq np hn -> sq b (np hn)')

elif checkpoint_core_attention:
context_layer = self._checkpointed_attention_forward(
query_layer,
key_layer,
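For reference, a minimal, self-contained sketch (not part of the PR) of the decode-time call added above: single-token queries attending over a key/value cache via `flash_attn_with_kvcache`, with the same layout rearranges as in the diff. It assumes flash-attn >= 2.2, einops, and a CUDA GPU; the shapes are arbitrary examples, and rotary embeddings are omitted (the PR applies them to q and k before the call).

```python
# Hedged sketch, not NeMo code: flash decoding with a KV cache for sq == 1.
# Assumes flash-attn >= 2.2 and a CUDA device; shapes are arbitrary examples.
import torch
from einops import rearrange
from flash_attn import flash_attn_with_kvcache

b, sq, sk, nheads, headdim = 2, 1, 128, 16, 64  # batch, query len, cache len, heads, head dim

# Megatron layout: [seq, batch, heads, head_dim]
q = torch.randn(sq, b, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn(sk, b, nheads, headdim, dtype=torch.float16, device="cuda")
v = torch.randn(sk, b, nheads, headdim, dtype=torch.float16, device="cuda")

# Same rearranges as the diff: flash-attn expects [batch, seq, heads, head_dim].
q_b = rearrange(q, 'sq b np hn -> b sq np hn')
k_b = rearrange(k, 'sk b np hn -> b sk np hn')
v_b = rearrange(v, 'sk b np hn -> b sk np hn')

out = flash_attn_with_kvcache(q=q_b, k_cache=k_b, v_cache=v_b, causal=True)

# Fold heads back into the hidden size, matching 'b sq np hn -> sq b (np hn)'.
out = rearrange(out, 'b sq np hn -> sq b (np hn)')
print(out.shape)  # torch.Size([1, 2, 1024])
```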