diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 8e6ddaeea..5f946d8b5 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -169,7 +169,7 @@ def flash_attn_split( T.fill(scores_max, -T.infinity(accum_dtype)) loop_range = T.ceildiv((seqlen_kv // num_split), block_N) - T.fill(K_shared, 0) + for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy( K[bid, (seqlen_kv // num_split) * sid +