From bd45cedb62f8ef40552a2c6ddfe948b1e5dab071 Mon Sep 17 00:00:00 2001
From: Cheng-Ping Hsieh
Date: Thu, 2 Nov 2023 11:04:13 -0700
Subject: [PATCH 1/2] Fix flash decoding precision

Signed-off-by: Cheng-Ping Hsieh
---
 nemo/collections/nlp/modules/common/megatron/attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py
index 5063c4aff8a9..a854264bb784 100644
--- a/nemo/collections/nlp/modules/common/megatron/attention.py
+++ b/nemo/collections/nlp/modules/common/megatron/attention.py
@@ -520,9 +520,9 @@ def forward(
             and not set_inference_key_value_memory
         ):
             # Mainly used for decoding with sq=1
-            q = rearrange(apply_rotary_pos_emb(query_layer, rotary_pos_emb[0]), 'sq b np hn -> b sq np hn')
-            k = rearrange(apply_rotary_pos_emb(key_layer, rotary_pos_emb[1]), 'sk b np hn -> b sk np hn')
-            v = rearrange(value_layer, 'sk b np hn -> b sk np hn')
+            q = _cast_if_autocast_enabled(rearrange(apply_rotary_pos_emb(query_layer, rotary_pos_emb[0]), 'sq b np hn -> b sq np hn'))
+            k = _cast_if_autocast_enabled(rearrange(apply_rotary_pos_emb(key_layer, rotary_pos_emb[1]), 'sk b np hn -> b sk np hn'))
+            v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
             context_layer = flash_attn_with_kvcache(
                 q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
             )

From 24d5536216bdd00f83e40c19f904013bf3ee317b Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 2 Nov 2023 18:08:07 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 nemo/collections/nlp/modules/common/megatron/attention.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py
index a854264bb784..09a9251ce46e 100644
--- a/nemo/collections/nlp/modules/common/megatron/attention.py
+++ b/nemo/collections/nlp/modules/common/megatron/attention.py
@@ -520,8 +520,12 @@ def forward(
             and not set_inference_key_value_memory
         ):
             # Mainly used for decoding with sq=1
-            q = _cast_if_autocast_enabled(rearrange(apply_rotary_pos_emb(query_layer, rotary_pos_emb[0]), 'sq b np hn -> b sq np hn'))
-            k = _cast_if_autocast_enabled(rearrange(apply_rotary_pos_emb(key_layer, rotary_pos_emb[1]), 'sk b np hn -> b sk np hn'))
+            q = _cast_if_autocast_enabled(
+                rearrange(apply_rotary_pos_emb(query_layer, rotary_pos_emb[0]), 'sq b np hn -> b sq np hn')
+            )
+            k = _cast_if_autocast_enabled(
+                rearrange(apply_rotary_pos_emb(key_layer, rotary_pos_emb[1]), 'sk b np hn -> b sk np hn')
+            )
             v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
             context_layer = flash_attn_with_kvcache(
                 q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
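
Note on the patches above (commentary, not part of the patch series): flash-attention kernels such as flash_attn_with_kvcache expect fp16/bf16 inputs that match the dtype of the KV cache, while the rotary-embedded query/key can come out of apply_rotary_pos_emb in fp32 under autocast, which appears to be the precision mismatch the first commit fixes by wrapping q, k, and v in _cast_if_autocast_enabled; the second commit is only the pre-commit line-length reformat. The sketch below is a minimal, hypothetical illustration of that casting pattern in plain PyTorch; cast_if_autocast_enabled is a stand-in written for this note and is not NeMo's actual _cast_if_autocast_enabled helper.

import torch


def cast_if_autocast_enabled(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: when CUDA autocast is active, cast floating-point
    # inputs to the autocast dtype so they match what a half-precision
    # attention kernel and its KV cache expect.
    if torch.is_autocast_enabled() and x.is_floating_point():
        return x.to(dtype=torch.get_autocast_gpu_dtype())
    return x


if torch.cuda.is_available():
    # fp32 query as it might come out of apply_rotary_pos_emb, shape [b, sq, np, hn]
    q = torch.randn(1, 1, 8, 64, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        q_cast = cast_if_autocast_enabled(q)
    print(q.dtype, q_cast.dtype)  # torch.float32 torch.bfloat16

Casting explicitly is needed here because the flash-attention call is an external CUDA extension, so autocast's automatic op-level casting does not apply to it.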