Skip to content

Commit 2f4e230

Browse files
fix formatting errors and add tracking of encoder lengths
1 parent d8eb97d commit 2f4e230

File tree

1 file changed

+27
-16
lines changed

1 file changed

+27
-16
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,7 @@ def __init__(
312312

313313
# mm_hash -> encoder_output
314314
self.encoder_cache: dict[str, torch.Tensor] = {}
315-
self._encoder_cudagraph_buffers: dict[tuple[Any, ...],
316-
torch.Tensor] = {}
315+
self._encoder_cudagraph_buffers: dict[tuple[Any, ...], torch.Tensor] = {}
317316

318317
self.use_aux_hidden_state_outputs = False
319318
# Set up speculative decoding.
@@ -1031,11 +1030,19 @@ def _get_encoder_seq_lens(
10311030
return None
10321031

10331032
# Build encoder_seq_lens array mapping request indices to
1034-
# encoder lengths for inputs scheduled in this batch
1035-
encoder_seq_lens = np.zeros((int(num_reqs), ), dtype=np.int32)
1036-
for req_id in scheduler_output.scheduled_encoder_inputs:
1037-
req_index = self.input_batch.req_id_to_index[req_id]
1038-
encoder_seq_lens[req_index] = self.max_encoder_len
1033+
# encoder lengths for all requests with encoder inputs.
1034+
# Note: This must include ALL requests with encoder features,
1035+
# not just those being scheduled in this step, because cross-attention
1036+
# needs encoder lengths during decode phase for CUDA graph compatibility.
1037+
encoder_seq_lens = np.zeros((int(num_reqs),), dtype=np.int32)
1038+
1039+
# Iterate through all active requests in the batch
1040+
for req_id in self.input_batch.req_ids[:num_reqs]:
1041+
req_state = self.requests.get(req_id)
1042+
# Check if this request has encoder inputs (multimodal features)
1043+
if req_state and req_state.mm_features:
1044+
req_index = self.input_batch.req_id_to_index[req_id]
1045+
encoder_seq_lens[req_index] = self.max_encoder_len
10391046

10401047
return encoder_seq_lens
10411048

@@ -1902,7 +1909,8 @@ def _extract_encoder_inputs(
19021909

19031910
if self._should_use_encoder_cudagraph_buffers():
19041911
encoder_features = self._prepare_encoder_inputs_for_cudagraph(
1905-
encoder_features)
1912+
encoder_features
1913+
)
19061914

19071915
return encoder_features
19081916

@@ -1918,7 +1926,8 @@ def _prepare_encoder_inputs_for_cudagraph(
19181926
return encoder_inputs
19191927

19201928
encoder_inputs["input_features"] = self._copy_to_cudagraph_buffer(
1921-
("input_features", ), input_features)
1929+
("input_features",), input_features
1930+
)
19221931
return encoder_inputs
19231932

19241933
def _copy_to_cudagraph_buffer(
@@ -1933,22 +1942,24 @@ def _copy_to_cudagraph_buffer(
19331942
buffer = torch.empty_like(value)
19341943
self._encoder_cudagraph_buffers[key] = buffer
19351944
else:
1936-
assert (buffer.shape == value.shape
1937-
and buffer.dtype == value.dtype
1938-
and buffer.device == value.device), (
1939-
"CUDAGraph buffer mismatch for encoder inputs.")
1945+
assert (
1946+
buffer.shape == value.shape
1947+
and buffer.dtype == value.dtype
1948+
and buffer.device == value.device
1949+
), "CUDAGraph buffer mismatch for encoder inputs."
19401950
buffer.copy_(value)
19411951
return buffer
19421952

19431953
if isinstance(value, list):
19441954
return [
1945-
self._copy_to_cudagraph_buffer(key_prefix + (idx, ), item)
1955+
self._copy_to_cudagraph_buffer(key_prefix + (idx,), item)
19461956
for idx, item in enumerate(value)
19471957
]
19481958
if isinstance(value, tuple):
19491959
return tuple(
1950-
self._copy_to_cudagraph_buffer(key_prefix + (idx, ), item)
1951-
for idx, item in enumerate(value))
1960+
self._copy_to_cudagraph_buffer(key_prefix + (idx,), item)
1961+
for idx, item in enumerate(value)
1962+
)
19521963

19531964
return value
19541965

0 commit comments

Comments (0)