@@ -1055,7 +1055,6 @@ def _dual_chunk_flash_attn_prefill_func(
10551055                    v_states_intra ,
10561056                    softmax_scale = softmax_scale ,
10571057                    causal = True ,
1058-                     block_table = block_table ,
10591058                    stage = "intra" ,
10601059                    vertical_indices = vertical_buffer ,
10611060                    slash_indices = slash_buffer ,
@@ -1070,7 +1069,6 @@ def _dual_chunk_flash_attn_prefill_func(
10701069                    v_states_intra ,
10711070                    softmax_scale = softmax_scale ,
10721071                    causal = True ,
1073-                     block_table = block_table ,
10741072                    stage = "intra" ,
10751073                    vertical_indices = intra_vertical_indices ,
10761074                    slash_indices = intra_slash_indices ,
@@ -1085,7 +1083,6 @@ def _dual_chunk_flash_attn_prefill_func(
10851083                        v_states_succ ,
10861084                        softmax_scale = softmax_scale ,
10871085                        causal = False ,
1088-                         block_table = block_table ,
10891086                        stage = "succ" ,
10901087                        vertical_indices = succ_vertical_buffer ,
10911088                        slash_indices = succ_slash_buffer ,
@@ -1100,7 +1097,6 @@ def _dual_chunk_flash_attn_prefill_func(
11001097                        v_states_succ ,
11011098                        softmax_scale = softmax_scale ,
11021099                        causal = False ,
1103-                         block_table = block_table ,
11041100                        stage = "succ" ,
11051101                        vertical_indices = succ_vertical_indices ,
11061102                        slash_indices = succ_slash_indices ,
@@ -1115,7 +1111,6 @@ def _dual_chunk_flash_attn_prefill_func(
11151111                        v_states_inter ,
11161112                        softmax_scale = softmax_scale ,
11171113                        causal = False ,
1118-                         block_table = block_table ,
11191114                        stage = "inter" ,
11201115                        vertical_indices = inter_vertical_buffer ,
11211116                        slash_indices = inter_slash_buffer ,
@@ -1130,7 +1125,6 @@ def _dual_chunk_flash_attn_prefill_func(
11301125                        v_states_inter ,
11311126                        softmax_scale = softmax_scale ,
11321127                        causal = False ,
1133-                         block_table = block_table ,
11341128                        stage = "inter" ,
11351129                        vertical_indices = inter_vertical_indices ,
11361130                        slash_indices = inter_slash_indices ,
@@ -1151,7 +1145,6 @@ def _do_flash_attn(
11511145        value_states : torch .Tensor ,
11521146        softmax_scale : float ,
11531147        causal : bool  =  True ,
1154-         block_table : torch .Tensor  =  None ,
11551148        max_seqlen_k : Optional [int ] =  None ,
11561149        stage : str  =  "intra" ,
11571150        vertical_indices : Optional [torch .Tensor ] =  None ,
@@ -1230,7 +1223,6 @@ def _do_flash_attn(
12301223                                      device = query_states .device ),
12311224            max_seqlen_k = max_seqlen_k ,
12321225            causal = causal ,
1233-             block_table = block_table .unsqueeze (0 ),
12341226            return_softmax_lse = True ,
12351227        )
12361228        softmax_lse  =  softmax_lse .view (q_len , q_heads , 1 ).transpose (0 ,
0 commit comments