@@ -78,9 +78,6 @@ def __init__(
         super().__init__(
             kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
         )
-        assert self.kv_cache_spec.block_size == 1, (
-            "AITER MLA only supports block size 1."
-        )
 
         self.compilation_config = vllm_config.compilation_config
         max_num_pages_per_req = cdiv(
@@ -94,6 +91,11 @@ def __init__(
         # so we can only use the persistent buffer if a cudagraph is actually
         # being used.
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
+            self.block_table_remapping = torch.zeros(
+                [max_num_reqs, max_num_pages_per_req * self.kv_cache_spec.block_size],
+                dtype=torch.int32,
+                device=device,
+            )
             self.paged_kv_indptr = torch.zeros(
                 max_num_reqs + 1, dtype=torch.int32, device=device
             )
@@ -119,13 +121,29 @@ def _build_decode(
         dcp_tot_seq_lens_device: torch.Tensor | None,
     ) -> AiterMLADecodeMetadata:
         page_size = self.kv_cache_spec.block_size
-        block_table_bounds = (seq_lens_device + page_size - 1) // page_size
         device = self.device
         num_reqs = seq_lens_device.size(0)
+        bs, _ = block_table_tensor.shape
+        block_table_tensor = (
+            block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
+        )
+        block_table_tensor = (
+            block_table_tensor
+            + torch.arange(
+                0,
+                page_size,
+                device=block_table_tensor.device,
+                dtype=block_table_tensor.dtype,
+            )[None, None, :]
+        )
+        block_table_tensor = block_table_tensor.view(bs, -1)
 
+        # after remapping, we assume the block size is already 1
+
+        max_blk_size_per_req = block_table_tensor.shape[-1]
         mask = torch.arange(
             block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
-        ).unsqueeze(0) < block_table_bounds.unsqueeze(1)
+        ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
         paged_kv_indices = block_table_tensor[mask]
 
         paged_kv_last_page_len = seq_lens_device % page_size
@@ -135,13 +153,19 @@ def _build_decode(
 
         paged_kv_indptr = torch.cat(
             [
-                torch.zeros(1, dtype=block_table_bounds.dtype, device=device),
-                block_table_bounds.cumsum(dim=0, dtype=torch.int32),
+                torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
+                seq_lens_device.cumsum(dim=0, dtype=torch.int32),
             ]
         )
 
         if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
             num_actual_pages = paged_kv_indices.size(0)
+            self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
+                block_table_tensor, non_blocking=True
+            )
+            block_table_tensor = self.block_table_remapping[
+                :num_reqs, :max_blk_size_per_req
+            ]
 
             self.paged_kv_indices[:num_actual_pages].copy_(
                 paged_kv_indices, non_blocking=True
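For reference, a minimal standalone sketch (not part of the diff; values are made up) of the remapping the new _build_decode code performs: each block id is expanded into page_size consecutive slot ids, so downstream code can treat the table as if block_size were 1. This is what allows the block-size-1 assert to be dropped from __init__, and why paged_kv_indptr can now cumsum seq_lens_device directly.

import torch

page_size = 4
# hypothetical block table for two requests, two KV-cache blocks each
block_table = torch.tensor([[3, 7], [5, 0]], dtype=torch.int32)

bs, _ = block_table.shape
# expand each block id into page_size consecutive per-token slot ids
remapped = block_table.unsqueeze(-1).expand(-1, -1, page_size) * page_size
remapped = remapped + torch.arange(page_size, dtype=block_table.dtype)[None, None, :]
remapped = remapped.view(bs, -1)

print(remapped)
# tensor([[12, 13, 14, 15, 28, 29, 30, 31],
#         [20, 21, 22, 23,  0,  1,  2,  3]], dtype=torch.int32)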