@@ -761,35 +761,24 @@ def split_decodes_prefills_and_extends(
761761
762762 query_lens = query_start_loc [1 :] - query_start_loc [:- 1 ]
763763 is_prefill = query_lens > decode_threshold
764- if not torch .any (is_prefill ):
765- return num_reqs , 0 , 0 , num_tokens , 0 , 0
766-
764+ is_pure_prefill = (seq_lens == query_lens ) & is_prefill
767765 first_prefill = is_prefill .int ().argmax (dim = - 1 ).item ()
768- assert torch .all (query_lens [first_prefill :] > decode_threshold )
769- assert torch .all (query_lens [:first_prefill ] <= decode_threshold )
770-
766+ first_pure_prefill = is_pure_prefill .int ().argmax (dim = - 1 ).item ()
771767 num_decodes = first_prefill
772768 num_decode_tokens = query_start_loc [first_prefill ].item ()
773-
774- query_lens_prefill = query_lens [first_prefill :]
775- seq_lens_prefill = seq_lens [first_prefill :]
776- is_extend = seq_lens_prefill != query_lens_prefill
777-
778- if torch .all (is_extend ):
779- num_extends = num_reqs - num_decodes
780- num_extend_tokens = num_tokens - num_decode_tokens
781- return (num_decodes , num_extends , 0 , num_decode_tokens , num_extend_tokens , 0 )
769+ if not torch .any (is_prefill ):
770+ return (num_decodes , 0 , 0 , num_decode_tokens , 0 , 0 )
782771
783772 num_prefills = num_reqs - num_decodes
784- first_extend = is_extend .int ().argmax (dim = - 1 ).item ()
773+ num_prefill_tokens = num_tokens - num_decode_tokens
774+ if not torch .any (is_pure_prefill ):
775+ return (num_decodes , num_prefills , 0 , num_decode_tokens , num_prefill_tokens , 0 )
785776
786- num_extends = first_extend
787- num_pure_prefills = num_prefills - first_extend
777+ num_extends = first_pure_prefill - num_decodes
778+ num_pure_prefills = num_reqs - first_pure_prefill
788779
789- num_extend_tokens = (
790- query_start_loc [num_extends + num_decodes ].item () - num_decode_tokens
791- )
792- num_pure_prefill_tokens = num_tokens - num_decode_tokens - num_extend_tokens
780+ num_pure_prefill_tokens = num_tokens - query_start_loc [first_pure_prefill ]
781+ num_extend_tokens = num_prefill_tokens - num_pure_prefill_tokens
793782 return (
794783 num_decodes ,
795784 num_extends ,
@@ -875,28 +864,6 @@ def reorder_batch_to_split_decodes_and_prefills(
875864 # NOTE for now we loosely use "decode" to mean requests where attention is
876865 # likely memory-bound and "prefill" to mean requests where attention is
877866 # likely compute-bound.
878- # rid = dist.get_rank()
879- rid = 0
880-
881- def print_order ():
882- if rid == 0 :
883- num_scheduled_tokens = [
884- scheduler_output .num_scheduled_tokens [id ] for id in input_batch .req_ids
885- ]
886- num_scheduled_tokens_np = np .array (num_scheduled_tokens )
887- num_computed_tokens_np = input_batch .num_computed_tokens_cpu [:num_reqs ]
888- print ("num scheduled tokens: " , num_scheduled_tokens_np , flush = True )
889- print ("num computed tokens: " , num_computed_tokens_np , flush = True )
890- is_decode = num_scheduled_tokens_np <= decode_threshold
891- is_extend = (~ is_decode ) & (num_computed_tokens_np > 0 )
892- is_prefill = (~ is_decode ) & (num_computed_tokens_np == 0 )
893- idx = np .arange (0 , is_decode .shape [0 ])
894- decodes = idx [is_decode ]
895- extends = idx [is_extend ]
896- prefills = idx [is_prefill ]
897- print ("decode: " , decodes , flush = True )
898- print ("extends: " , extends , flush = True )
899- print ("prefills: " , prefills , flush = True )
900867
901868 num_reqs = len (input_batch .req_ids )
902869 num_scheduled_tokens = [
0 commit comments