@@ -2636,12 +2636,12 @@ def _finalize_nwor_window(
26362636 sampled_token_ids : torch .Tensor | None ,
26372637 ) -> None :
26382638 manager = self ._deferred_write_manager
2639+ debug = getattr (self , "_nwor_debug" , False )
26392640 if not manager .window_active :
2640- if getattr ( self , "_nwor_debug" , False ) :
2641+ if debug :
26412642 logger .debug ("NWOR: Finalize called but window not active" )
26422643 return
26432644
2644- debug = getattr (self , "_nwor_debug" , False )
26452645 if debug :
26462646 logger .debug ("NWOR: Finalizing window" )
26472647 try :
@@ -2709,32 +2709,51 @@ def _compute_nwor_acceptance(
27092709 spec_decode_metadata , sampled_token_ids , total_tokens , work_device
27102710 )
27112711 if mask is not None :
2712- accepted_counts : list [int ] = []
2712+ # Batch all sums to minimize GPU-CPU synchronization
2713+ sum_tensors : list [torch .Tensor | None ] = []
27132714 start = 0
27142715 for draft_count in num_draft_tokens :
27152716 count = int (draft_count )
27162717 if count == 0 :
2717- accepted_counts .append (0 )
2718+ sum_tensors .append (None )
27182719 continue
27192720 slice_view = mask [start : start + count ]
2720- accepted_counts .append (int ( slice_view .sum (). item () ))
2721+ sum_tensors .append (slice_view .sum ())
27212722 start += count
2723+
2724+ # Single sync for all non-zero counts
2725+ valid_sums = [s for s in sum_tensors if s is not None ]
2726+ if valid_sums :
2727+ all_counts_tensor = torch .stack (valid_sums ).cpu ()
2728+ counts_list = all_counts_tensor .tolist ()
2729+ else :
2730+ counts_list = []
2731+
2732+ # Reconstruct accepted_counts with zeros
2733+ accepted_counts : list [int ] = []
2734+ counts_idx = 0
2735+ for s in sum_tensors :
2736+ if s is None :
2737+ accepted_counts .append (0 )
2738+ else :
2739+ accepted_counts .append (int (counts_list [counts_idx ]))
2740+ counts_idx += 1
2741+
27222742 if return_mask and mask .device != target_device :
27232743 mask = mask .to (device = target_device )
27242744 if not return_mask :
27252745 mask = None
27262746 return accepted_counts , mask
27272747
27282748 draft_ids = spec_decode_metadata .draft_token_ids
2729- if draft_ids .device != work_device :
2730- draft_ids = draft_ids .to (device = work_device )
2731- draft_ids = draft_ids .to (dtype = sampled_token_ids .dtype , copy = False )
2749+ # Combine device and dtype conversion in single operation
2750+ draft_ids = draft_ids .to (device = work_device , dtype = sampled_token_ids .dtype , copy = False )
27322751
27332752 if return_mask :
27342753 mask_work = torch .zeros (total_tokens , dtype = torch .bool , device = work_device )
27352754 else :
27362755 mask_work = None
2737- accepted_counts = []
2756+ sum_tensors : list [ torch . Tensor | None ] = []
27382757
27392758 if sampled_token_ids .ndim == 0 :
27402759 zero_counts = [0 for _ in num_draft_tokens ]
@@ -2749,21 +2768,20 @@ def _compute_nwor_acceptance(
27492768 leading = sampled_token_ids .shape [0 ]
27502769 sampled_token_ids = sampled_token_ids .reshape (leading , - 1 )
27512770
2771+ # Hoist device/dtype conversion outside loop (all rows share same device/dtype)
2772+ sampled_token_ids = sampled_token_ids .to (device = work_device , dtype = draft_ids .dtype )
2773+
27522774 start = 0
27532775 for req_idx , draft_count in enumerate (num_draft_tokens ):
27542776 draft_count = int (draft_count )
27552777 if draft_count == 0 :
2756- accepted_counts .append (0 )
2778+ sum_tensors .append (None )
27572779 continue
27582780 end = start + draft_count
27592781 if req_idx >= sampled_token_ids .shape [0 ]:
27602782 row = sampled_token_ids .new_empty ((0 ,), dtype = sampled_token_ids .dtype )
27612783 else :
27622784 row = sampled_token_ids [req_idx ]
2763- if row .device != work_device :
2764- row = row .to (device = work_device )
2765- if row .dtype != draft_ids .dtype :
2766- row = row .to (dtype = draft_ids .dtype )
27672785 if row .ndim == 0 :
27682786 row = row .unsqueeze (0 )
27692787 elif row .ndim > 1 :
@@ -2784,12 +2802,30 @@ def _compute_nwor_acceptance(
27842802
27852803 if mask_work is not None :
27862804 mask_work [start :end ] = prefix_full
2787- accepted_counts .append (int ( prefix_full .sum (). item () ))
2805+ sum_tensors .append (prefix_full .sum ())
27882806 start = end
27892807
27902808 if start != total_tokens :
27912809 return None , None
27922810
2811+ # Batch all sums to minimize GPU-CPU synchronization
2812+ valid_sums = [s for s in sum_tensors if s is not None ]
2813+ if valid_sums :
2814+ all_counts_tensor = torch .stack (valid_sums ).cpu ()
2815+ counts_list = all_counts_tensor .tolist ()
2816+ else :
2817+ counts_list = []
2818+
2819+ # Reconstruct accepted_counts with zeros
2820+ accepted_counts : list [int ] = []
2821+ counts_idx = 0
2822+ for s in sum_tensors :
2823+ if s is None :
2824+ accepted_counts .append (0 )
2825+ else :
2826+ accepted_counts .append (int (counts_list [counts_idx ]))
2827+ counts_idx += 1
2828+
27932829 if not return_mask :
27942830 return accepted_counts , None
27952831 assert mask_work is not None
@@ -2842,10 +2878,8 @@ def _scv_vectorized_mask(
28422878 if draft_ids .device != device :
28432879 draft_ids = draft_ids .to (device = device )
28442880
2845- cu = spec_decode_metadata .cu_num_draft_tokens .to (device = device )
2846- cu_int32 = cu
2847- if cu .dtype != torch .int32 :
2848- cu_int32 = cu .to (torch .int32 )
2881+ # Combine device and dtype conversion in single operation
2882+ cu_int32 = spec_decode_metadata .cu_num_draft_tokens .to (device = device , dtype = torch .int32 )
28492883
28502884 if self ._scv_mode == "graph" and self ._scv_capture_available :
28512885 if not hasattr (torch .cuda , "CUDAGraph" ):
@@ -2856,7 +2890,11 @@ def _scv_vectorized_mask(
28562890 else :
28572891 num_reqs = len (spec_decode_metadata .num_draft_tokens )
28582892 dtype = sampled_token_ids .dtype
2859- cu_tuple = tuple (cu_int32 .cpu ().tolist ())
2893+ # Compute cumulative sum on CPU to avoid GPU->CPU sync
2894+ import itertools
2895+ cu_tuple = tuple (itertools .accumulate (
2896+ [0 ] + list (spec_decode_metadata .num_draft_tokens )
2897+ ))
28602898 key = (
28612899 num_reqs ,
28622900 max_spec_len ,
0 commit comments