
Commit

Fix flash decoding precision (NVIDIA#7852)
* Fix flash decoding precision

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Cheng-Ping Hsieh <chsieh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
hsiehjackson and pre-commit-ci[bot] authored Nov 6, 2023
1 parent f1a9608 commit 7cf1cc4
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions nemo/collections/nlp/modules/common/megatron/attention.py
@@ -520,9 +520,13 @@ def forward(
                 and not set_inference_key_value_memory
             ):
                 # Mainly used for decoding with sq=1
-                q = rearrange(apply_rotary_pos_emb(query_layer, rotary_pos_emb[0]), 'sq b np hn -> b sq np hn')
-                k = rearrange(apply_rotary_pos_emb(key_layer, rotary_pos_emb[1]), 'sk b np hn -> b sk np hn')
-                v = rearrange(value_layer, 'sk b np hn -> b sk np hn')
+                q = _cast_if_autocast_enabled(
+                    rearrange(apply_rotary_pos_emb(query_layer, rotary_pos_emb[0]), 'sq b np hn -> b sq np hn')
+                )
+                k = _cast_if_autocast_enabled(
+                    rearrange(apply_rotary_pos_emb(key_layer, rotary_pos_emb[1]), 'sk b np hn -> b sk np hn')
+                )
+                v = _cast_if_autocast_enabled(rearrange(value_layer, 'sk b np hn -> b sk np hn'))
                 context_layer = flash_attn_with_kvcache(
                     q=q, k_cache=k, v_cache=v, causal=self.attn_mask_type == AttnMaskType.causal,
                 )
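For context on the change: the FlashAttention kv-cache kernel expects half-precision (fp16/bf16) inputs, so when this decoding path runs under torch autocast the rearranged q/k/v tensors are now cast to the active autocast dtype before the call. The sketch below illustrates what such a cast helper could look like; the name cast_if_autocast_enabled and its exact behavior here are assumptions for illustration, not NeMo's actual _cast_if_autocast_enabled implementation.

import torch

def cast_if_autocast_enabled(*tensors):
    # Hypothetical helper (name and behavior assumed for illustration):
    # when CUDA autocast is active, cast floating-point tensors to the
    # autocast dtype (fp16/bf16) so FlashAttention receives a uniform
    # half-precision dtype instead of a mix of fp32 and fp16.
    if not torch.is_autocast_enabled():
        return tensors if len(tensors) > 1 else tensors[0]
    dtype = torch.get_autocast_gpu_dtype()
    casted = tuple(
        t.to(dtype) if torch.is_tensor(t) and t.is_floating_point() else t
        for t in tensors
    )
    return casted if len(casted) > 1 else casted[0]

Used as q, k, v = cast_if_autocast_enabled(q, k, v) inside an autocast region, all three tensors then reach flash_attn_with_kvcache in the same half-precision dtype, which is the precision issue this commit addresses.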
