@@ -2828,7 +2828,7 @@ def _get_mm_dummy_batch(
     def _dummy_run(
         self,
         num_tokens: int,
-        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+        cudagraph_runtime_mode: Optional[CUDAGraphMode] = None,
         force_attention: bool = False,
         uniform_decode: bool = False,
         allow_microbatching: bool = True,
@@ -2844,6 +2844,8 @@ def _dummy_run(
         Args:
             num_tokens: Number of tokens to run the dummy forward pass.
             cudagraph_runtime_mode: used to control the behavior.
+            - if not set, the cudagraph mode will be determined using
+                the self.cudagraph_dispatcher.
             - CUDAGraphMode.NONE: No cudagraph, for warm up and profile run
             - CUDAGraphMode.PIECEWISE: Piecewise cudagraph.
             - CUDAGraphMode.FULL: Full cudagraph, attention metadata is
@@ -2857,7 +2859,7 @@ def _dummy_run(
                 (1 token) and prefill (multiple tokens) requests.
             remove_lora: If False, dummy LoRAs are not destroyed after the run
         """
-        assert cudagraph_runtime_mode in {
+        assert cudagraph_runtime_mode is None or cudagraph_runtime_mode in {
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
 
@@ -2899,10 +2901,6 @@ def _dummy_run(
         elif uniform_decode:
             assert not create_mixed_batch
             num_reqs = cdiv(num_tokens, max_query_len)
-            assert num_reqs <= max_num_reqs, \
-                f"Do not capture num_reqs {num_reqs} > max_num_reqs " \
-                f"{max_num_reqs} for uniform batch. Num tokens: " \
-                f"{num_tokens}, max_query_len: {max_query_len}"
             num_scheduled_tokens_list = [max_query_len] * num_reqs
             if num_tokens % max_query_len != 0:
                 num_scheduled_tokens_list[-1] = num_tokens % max_query_len
@@ -3043,18 +3041,20 @@ def _dummy_run(
 
         intermediate_tensors = self.sync_and_slice_intermediate_tensors(
             num_tokens, None, False)
-        if cudagraph_runtime_mode == CUDAGraphMode.NONE:
-            batch_descriptor = None
-        else:
-            # filter out the valid batch descriptor
-            _cg_mode, batch_descriptor = \
-                self.cudagraph_dispatcher.dispatch(
-                    BatchDescriptor(num_tokens=num_tokens,
-                                    uniform_decode=uniform_decode))
-            # sanity check
-            assert cudagraph_runtime_mode == _cg_mode, (
+
+        # filter out the valid batch descriptor
+        _cg_mode, batch_descriptor = self.cudagraph_dispatcher.dispatch(
+            BatchDescriptor(num_tokens=num_tokens,
+                            uniform_decode=uniform_decode))
+        if cudagraph_runtime_mode is not None:
+            # we allow forcing NONE when the dispatcher disagrees to support
+            # warm ups for cudagraph capture
+            assert cudagraph_runtime_mode == CUDAGraphMode.NONE or \
+                cudagraph_runtime_mode == _cg_mode, (
                 f"Cudagraph runtime mode mismatch at dummy_run. "
                 f"Expected {_cg_mode}, but got {cudagraph_runtime_mode}.")
+        else:
+            cudagraph_runtime_mode = _cg_mode
 
         if ubatch_slices is not None:
             num_tokens = num_tokens // 2
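
A minimal usage sketch of the resulting contract (not part of the diff; the `runner` instance and token count are hypothetical, and the `CUDAGraphMode` import path is assumed):

    # `runner` is assumed to be an already-constructed GPUModelRunner.
    from vllm.config import CUDAGraphMode  # import path assumed

    # No mode passed: the cudagraph_dispatcher now picks the runtime mode
    # and the matching batch descriptor for the dummy batch.
    runner._dummy_run(num_tokens=512)

    # Forcing NONE stays legal even when the dispatcher would choose
    # PIECEWISE/FULL; eager warm-up before cudagraph capture relies on this.
    runner._dummy_run(num_tokens=512,
                      cudagraph_runtime_mode=CUDAGraphMode.NONE)

    # Any other explicit mode must match the dispatcher's decision, or the
    # sanity-check assert in the hunk above fires.
    runner._dummy_run(num_tokens=512,
                      cudagraph_runtime_mode=CUDAGraphMode.FULL)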