@@ -312,8 +312,7 @@ def __init__(
312312
313313 # mm_hash -> encoder_output
314314 self .encoder_cache : dict [str , torch .Tensor ] = {}
315- self ._encoder_cudagraph_buffers : dict [tuple [Any , ...],
316- torch .Tensor ] = {}
315+ self ._encoder_cudagraph_buffers : dict [tuple [Any , ...], torch .Tensor ] = {}
317316
318317 self .use_aux_hidden_state_outputs = False
319318 # Set up speculative decoding.
@@ -1031,11 +1030,19 @@ def _get_encoder_seq_lens(
10311030 return None
10321031
10331032 # Build encoder_seq_lens array mapping request indices to
1034- # encoder lengths for inputs scheduled in this batch
1035- encoder_seq_lens = np .zeros ((int (num_reqs ), ), dtype = np .int32 )
1036- for req_id in scheduler_output .scheduled_encoder_inputs :
1037- req_index = self .input_batch .req_id_to_index [req_id ]
1038- encoder_seq_lens [req_index ] = self .max_encoder_len
1033+ # encoder lengths for all requests with encoder inputs.
1034+ # Note: This must include ALL requests with encoder features,
1035+ # not just those being scheduled in this step, because cross-attention
1036+ # needs encoder lengths during decode phase for CUDA graph compatibility.
1037+ encoder_seq_lens = np .zeros ((int (num_reqs ),), dtype = np .int32 )
1038+
1039+ # Iterate through all active requests in the batch
1040+ for req_id in self .input_batch .req_ids [:num_reqs ]:
1041+ req_state = self .requests .get (req_id )
1042+ # Check if this request has encoder inputs (multimodal features)
1043+ if req_state and req_state .mm_features :
1044+ req_index = self .input_batch .req_id_to_index [req_id ]
1045+ encoder_seq_lens [req_index ] = self .max_encoder_len
10391046
10401047 return encoder_seq_lens
10411048
@@ -1902,7 +1909,8 @@ def _extract_encoder_inputs(
19021909
19031910 if self ._should_use_encoder_cudagraph_buffers ():
19041911 encoder_features = self ._prepare_encoder_inputs_for_cudagraph (
1905- encoder_features )
1912+ encoder_features
1913+ )
19061914
19071915 return encoder_features
19081916
@@ -1918,7 +1926,8 @@ def _prepare_encoder_inputs_for_cudagraph(
19181926 return encoder_inputs
19191927
19201928 encoder_inputs ["input_features" ] = self ._copy_to_cudagraph_buffer (
1921- ("input_features" , ), input_features )
1929+ ("input_features" ,), input_features
1930+ )
19221931 return encoder_inputs
19231932
19241933 def _copy_to_cudagraph_buffer (
@@ -1933,22 +1942,24 @@ def _copy_to_cudagraph_buffer(
19331942 buffer = torch .empty_like (value )
19341943 self ._encoder_cudagraph_buffers [key ] = buffer
19351944 else :
1936- assert (buffer .shape == value .shape
1937- and buffer .dtype == value .dtype
1938- and buffer .device == value .device ), (
1939- "CUDAGraph buffer mismatch for encoder inputs." )
1945+ assert (
1946+ buffer .shape == value .shape
1947+ and buffer .dtype == value .dtype
1948+ and buffer .device == value .device
1949+ ), "CUDAGraph buffer mismatch for encoder inputs."
19401950 buffer .copy_ (value )
19411951 return buffer
19421952
19431953 if isinstance (value , list ):
19441954 return [
1945- self ._copy_to_cudagraph_buffer (key_prefix + (idx , ), item )
1955+ self ._copy_to_cudagraph_buffer (key_prefix + (idx ,), item )
19461956 for idx , item in enumerate (value )
19471957 ]
19481958 if isinstance (value , tuple ):
19491959 return tuple (
1950- self ._copy_to_cudagraph_buffer (key_prefix + (idx , ), item )
1951- for idx , item in enumerate (value ))
1960+ self ._copy_to_cudagraph_buffer (key_prefix + (idx ,), item )
1961+ for idx , item in enumerate (value )
1962+ )
19521963
19531964 return value
19541965
0 commit comments