 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_quantize)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
 
 try:
@@ -471,43 +472,46 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
               common_prefix_len: int) -> M:
         assert self._num_decodes + self._num_prefills == num_reqs
 
+        # Note(simon): be careful about the CPU <> GPU memory movement in this
+        # function. We should avoid GPU -> CPU sync as much as possible because
+        # it blocks on all previous kernels.
         device = self.runner.device
-        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
-            device, non_blocking=True)
-        seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
-                                                          non_blocking=True)
         block_table = (
             self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
+        query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
+            device, non_blocking=True)
         slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
         input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
             device, non_blocking=True).long()
 
+        seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
+        seq_lens = seq_lens_cpu.to(device, non_blocking=True)
+        max_query_len = seq_lens_cpu.max().item()
+
         prefill_metadata = None
         if self._num_prefills > 0:
             reqs_start = self._num_decodes  # prefill_start
             tokens_start = self._num_decode_tokens
 
             context_lens_cpu = self.runner.input_batch.\
                 num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
-            context_lens = context_lens_cpu.to(device, non_blocking=True)
+            max_context_len_cpu = context_lens_cpu.max().item()
+            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
 
             chunked_context_metadata = None
             if self.chunked_prefill_enabled and self._num_prefills > 0 \
-                and context_lens.max() > 0:
+                and max_context_len_cpu > 0:
                 # NOTE: it is recommended you read the `Chunked Prefill` section
                 # in the comment at the top of the file before trying to
                 # understand the following code
 
-                num_prefills_with_context = (context_lens > 0).sum().item()
-
                 # currently we allocate an equal amount of workspace for each
                 # prefill in the batch, we could probably use a more advanced
                 # algorithm here and allocate more workspace to prefills with
                 # longer context lengths
-                max_context_chunk = \
-                    self.chunked_prefill_workspace_size \
-                        // num_prefills_with_context
+                max_context_chunk = (self.chunked_prefill_workspace_size //
+                                     num_prefills_with_context_cpu)
 
                 # align max_context_chunk to page_size by rounding down,
                 # currently the `gather_cache` kernel cannot handle
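
As a side note, a minimal sketch of the CPU <-> GPU behaviour the Note(simon) comment above is guarding against; the tensor names here are made up, only the pattern mirrors the patch: reductions such as `.max().item()` on the host-side copies never stall the GPU, whereas the same call on a CUDA tensor forces a full device synchronization.

import torch

# per-request sequence lengths, kept on the host
seq_lens_cpu = torch.randint(1, 4096, (256,), dtype=torch.int32)
if torch.cuda.is_available():
    seq_lens_cpu = seq_lens_cpu.pin_memory()

max_query_len = seq_lens_cpu.max().item()  # pure CPU work, no device sync

if torch.cuda.is_available():
    # async H2D copy from pinned memory; it can overlap with queued kernels
    seq_lens = seq_lens_cpu.to("cuda", non_blocking=True)
    # by contrast, this reduction would block until the GPU drains its queue:
    # max_query_len = seq_lens.max().item()
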
@@ -516,30 +520,35 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                                                self.page_size)
 
                 assert max_context_chunk > 0
-                num_chunks = cdiv(context_lens.max(), max_context_chunk)
+                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
 
                 # if `max_context_chunk = 256`, `num_chunks = 3`, and
                 # `num_prefills_with_context = 4`, create a tensor that looks
                 # like
                 # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
+                # Note(simon): this is done on CPU because of downstream's
+                # use of `to_list`.
                 chunk_starts = \
-                    torch.arange(num_chunks, device=device, dtype=torch.int32) \
+                    torch.arange(num_chunks, dtype=torch.int32) \
                     .unsqueeze(1).expand(-1, self._num_prefills) \
                     * max_context_chunk
-                chunk_ends = torch.min(context_lens.unsqueeze(0),
+                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                        chunk_starts + max_context_chunk)
                 chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
-                _chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
-                    torch.int32)
-                zero = torch.zeros(num_chunks,
-                                   dtype=torch.int32,
-                                   device=device).unsqueeze(-1)
+
+                cu_seq_lens_cpu = torch.zeros(num_chunks,
+                                              self._num_prefills + 1,
+                                              dtype=torch.int32,
+                                              pin_memory=True)
+                torch.cumsum(chunk_seq_lens,
+                             dim=1,
+                             out=cu_seq_lens_cpu[:, 1:],
+                             dtype=torch.int32)
 
                 chunked_context_metadata = \
                     MLACommonPrefillMetadata.ChunkedContextMetadata(
-                        cu_seq_lens=torch.cat(
-                            [zero, _chunk_cu_seq_lens], dim=1),
-                        starts=chunk_starts,
+                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
+                        starts=chunk_starts.to(device, non_blocking=True),
                         seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                         max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                         workspace=self.chunked_prefill_workspace,
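
For reference, a standalone sketch of the chunking arithmetic above, using the toy numbers from the in-code comment (`max_context_chunk = 256`, three chunks, four prefills with context). The context lengths are invented, and `pin_memory` is omitted so the sketch also runs without CUDA; everything else mirrors the new CPU-side path, including writing the cumulative lengths into a preallocated buffer with a leading zero column.

import torch

# toy inputs: four prefills with context, 256 context tokens per chunk each
context_lens_cpu = torch.tensor([700, 300, 512, 64], dtype=torch.int32)
max_context_chunk = 256
num_chunks = -(-int(context_lens_cpu.max()) // max_context_chunk)  # cdiv -> 3
num_prefills = context_lens_cpu.numel()

# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

# cumulative lengths per chunk with a leading zero column, written in place
# so the buffer could be pinned and copied to the GPU asynchronously
cu_seq_lens_cpu = torch.zeros(num_chunks, num_prefills + 1, dtype=torch.int32)
torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:],
             dtype=torch.int32)
print(cu_seq_lens_cpu.tolist())
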
@@ -553,7 +562,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
                 block_table=block_table[reqs_start:, ...],
                 query_start_loc=query_start_loc[reqs_start:] -
                 query_start_loc[reqs_start],
-                max_query_len=seq_lens[reqs_start:].max().item(),
+                max_query_len=max_query_len,
                 chunked_context=chunked_context_metadata,
             )
 
@@ -629,7 +638,9 @@ def __init__(
         # already inside an attention custom op), pull out the forward
         # method from the rotary embedding and call it directly
         # TODO(lucas): we should probably find a cleaner way to do this
-        self.rotary_emb = rotary_emb._forward_method
+        self.rotary_emb = rotary_emb.forward_native
+        if current_platform.is_cuda():
+            self.rotary_emb = rotary_emb.forward_cuda
 
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
@@ -1043,17 +1054,20 @@ def forward(
             decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
             decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
+
             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
-                attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
+                attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
+                decode_k_pe)
 
         if has_prefill:
             assert attn_metadata.prefill is not None
             prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+
             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
-                attn_metadata.prefill.input_positions, prefill_q_pe,
-                prefill_k_pe)
+                attn_metadata.prefill.input_positions,
+                prefill_q_pe.contiguous(), prefill_k_pe)
 
         # write the latent and rope to kv cache
         if kv_cache.numel() > 0:
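
A plausible reason for the added `.contiguous()` calls (an assumption, not something stated in the patch) is that slicing the rope portion out of the packed head dimension yields a non-contiguous view, which the selected rotary-embedding kernel may not accept directly; a quick check with made-up shapes:

import torch

num_tokens, num_heads = 8, 16
qk_nope_head_dim, qk_rope_head_dim = 128, 64
qk_head_dim = qk_nope_head_dim + qk_rope_head_dim

q = torch.randn(num_tokens, num_heads, qk_head_dim)
q_pe = q[..., qk_nope_head_dim:]          # view over the packed projection
print(q_pe.is_contiguous())               # False
print(q_pe.contiguous().is_contiguous())  # True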