File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -620,7 +620,7 @@ def unified_attention(
620620 num_queries_per_kv = num_query_heads // num_kv_heads
621621 head_size = q .shape [2 ]
622622
623- BLOCK_M = 64 if triton . next_power_of_2 ( int ( max_seqlen_q )) > 1 else 16
623+ BLOCK_M = 64 if max_seqlen_q > 1 else 16
624624 BLOCK_Q = BLOCK_M // num_queries_per_kv # for 3d
625625
626626 # Ideally we would launch the kernel with:
@@ -637,7 +637,7 @@ def unified_attention(
637637 # if batch contains a prefill
638638 if max_seqlen_q > 1 or total_num_q_blocks * num_kv_heads > 128 :
639639
640- BLOCK_N = 16 if triton . next_power_of_2 ( int ( max_seqlen_k )) < 128 else 64
640+ BLOCK_N = 16 if max_seqlen_k <= 64 else 64
641641
642642 grid = lambda META : (q .shape [0 ] // (META [
643643 'BLOCK_M' ] // num_queries_per_kv ) + num_seqs , num_kv_heads )
You can’t perform that action at this time.
0 commit comments