
Commit ed61cc7

[Fix] Properly handle cached mm_inputs on Encode instance
Signed-off-by: LastZhabka <sakhmoldin.mukhammadarif@gmail.com>
1 parent d8c7cc6 commit ed61cc7

3 files changed: +43 -7 lines changed

vllm/separated_encode/README.md

Lines changed: 6 additions & 2 deletions
@@ -220,7 +220,9 @@ Separate EncoderScheduler class implementation is provided for encode instance s
 
 The EncoderScheduler is a specialized scheduler for encode instances that focuses on only multimodal input scheduling. It maintains an `_allocated` dictionary to track allocated encoder cache entries, their sizes and hashes. This dictionary is used to allow us to free up logical space without storing the request itself, which enables us to end the request before the data is transferred.
 
-Currently the encode scheduler schedules all multimodal inputs for a request at once in the `schedule()` method. It checks if there's sufficient encoder cache space and budget before allocating all inputs together. A request on the encode instance is considered finished when all its multimodal embeddings have been computed, so all requests are finished in 1 iteration after scheduling, transfer is handled separately in encoder cache connectors, space allocated for encoder cache is deallocated only after transfers, not after request finish.
+Currently the encode scheduler schedules all multimodal inputs for a request at once in the `schedule()` method. It checks that there is sufficient encoder cache space and budget before allocating all inputs together. Note that if an input is already cached, we still add it to `scheduled_encoder_inputs`, but we do not allocate space for it, and the model runner skips encoder execution for such elements; this is necessary because in the `model_runner` a signal must be sent to the `ECConnector` for each `mm_input`.
+
+A request on the encode instance is considered finished when all its multimodal embeddings have been computed, so every request finishes one iteration after it is scheduled. The transfer itself is handled separately by the encoder cache connectors, and the space allocated in the encoder cache is deallocated only after the transfer completes, not when the request finishes.
 
 In the `update_from_output()` method, the scheduler goes through transferred multimodal data IDs and frees the mm inputs in encoder cache manager.
 

@@ -250,7 +252,9 @@ This wrapper runs on encode instances and processes multimodal inputs. It execut
 
 The encode instance doesn't need KV cache since it only runs vision part of MLLM. The wrapper overrides `initialize_kv_cache_tensors` and `initialize_kv_cache` to return empty results, freeing up GPU memory for larger encoder cache storage.
 
-During execution, the wrapper executes encoding for scheduled multimodal inputs and inserts enocder output in encoder cache connector. Since no text generation happens here, it returns empty ModelRunnerOutput with additional transfered encoder outputs information in ModelRunnerOutput, this information is used in encoder scheduler to free the space in encoder cache manager.
+During execution, the wrapper runs encoding for the scheduled multimodal inputs and inserts the encoder output into the encoder cache connector. Due to the nature of the encode scheduler, `scheduler_output.scheduled_encoder_inputs` can contain already cached inputs or multiple copies of the same multimodal input; since the cache is already present (or about to be produced by another scheduled copy), we can simply skip encoding for such `mm_inputs`. We therefore temporarily remove cached inputs, as well as inputs whose `mm_hash` already appears elsewhere in `scheduled_encoder_inputs`, and after execution we put all removed entries back into `scheduler_output`. The motivation for sending all multimodal inputs to the `model_runner` is given in the `EncoderScheduler` section.
+
+Since no text generation happens here, it returns an almost empty ModelRunnerOutput that additionally carries information about the transferred encoder outputs; the encoder scheduler uses this information to free space in the encoder cache manager.
 
 #### DisaggPrefillDecodeGPURunnerWrapper (Prefill/(Prefill+Decode) Instance)
 

vllm/separated_encode/sched/encoder_scheduler.py

Lines changed: 14 additions & 4 deletions
@@ -135,7 +135,9 @@ def schedule(self) -> SchedulerOutput:
 
             num_tokens_to_schedule = 0
             can_allocate_all = True
-            encoder_inputs_to_schedule = []
+            encoder_inputs_to_schedule = []
+            is_cached = []
+
             for input_id, pos_info in enumerate(mm_positions):
                 num_encoder_tokens = pos_info.length
                 if (
@@ -144,6 +146,13 @@ def schedule(self) -> SchedulerOutput:
                         request, input_id
                     )
                 ):
+                    # On Encoder instance we need to send all inputs to model runner
+                    # because we need to pass (req_id, input_id) to model runner's
+                    # ec connector, to send the cache to PD instance, so we will add
+                    # it to the scheduled encoder inputs without changing budget
+                    # and in model runner we will just skip all calculated values
+                    encoder_inputs_to_schedule.append(input_id)
+                    is_cached.append(True)
                     continue
                 if not self.encoder_cache_manager.can_allocate(
                     request=request,
@@ -156,6 +165,7 @@ def schedule(self) -> SchedulerOutput:
                 num_tokens_to_schedule += num_encoder_tokens
                 new_encoder_compute_budget -= num_encoder_tokens
                 encoder_inputs_to_schedule.append(input_id)
+                is_cached.append(False)
 
             # NOTE: Note that all updates from loop above are not applied
             # if we can't allocate all mm_inputs
@@ -179,10 +189,11 @@ def schedule(self) -> SchedulerOutput:
             scheduled_encoder_inputs[req_id] = encoder_inputs_to_schedule
 
             # Allocate the encoder cache.
-            for input_id in encoder_inputs_to_schedule:
+            for input_id, is_cached_input in zip(encoder_inputs_to_schedule, is_cached):
                 mm_hash = request.mm_hashes[input_id]
                 num_encoder_tokens = request.get_num_encoder_tokens(input_id)
-                self.encoder_cache_manager.allocate(request, input_id)
+                if not is_cached_input:
+                    self.encoder_cache_manager.allocate(request, input_id)
                 self.ec_connector.schedule_send_encoder_cache_metadata(
                     req_id,
                     input_id,
@@ -216,7 +227,6 @@ def schedule(self) -> SchedulerOutput:
             structured_output_request_ids={},
             grammar_bitmask=None,
         )
-        logger.debug(f"Request (8) ")
 
         self.finished_req_ids = set()
         return scheduler_output
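The invariant this change preserves can be shown with a small standalone sketch: every scheduled (req_id, input_id) pair still gets a cache-metadata send, while only uncached inputs take encoder cache space. `FakeCacheManager` and `FakeConnector` below are hypothetical stand-ins for illustration only, not the real vLLM `EncoderCacheManager` or `ECConnector` classes.

# Hypothetical stand-ins used only to illustrate the allocate-vs-send split
# introduced by the parallel is_cached flags.
class FakeCacheManager:
    def __init__(self):
        self.allocated = []
    def allocate(self, req_id, input_id):
        self.allocated.append((req_id, input_id))

class FakeConnector:
    def __init__(self):
        self.sends = []
    def schedule_send(self, req_id, input_id):
        self.sends.append((req_id, input_id))

manager, connector = FakeCacheManager(), FakeConnector()
scheduled, is_cached = [0, 1, 2], [False, True, False]
for input_id, cached in zip(scheduled, is_cached):
    if not cached:
        manager.allocate("req-0", input_id)     # only uncached inputs take space
    connector.schedule_send("req-0", input_id)  # but every input is announced

assert manager.allocated == [("req-0", 0), ("req-0", 2)]
assert connector.sends == [("req-0", 0), ("req-0", 1), ("req-0", 2)]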

vllm/separated_encode/worker/gpu_epd_vm_wrapper.py

Lines changed: 23 additions & 1 deletion
@@ -112,8 +112,30 @@ def execute_model(
         scheduler.
         """
         self._update_states(scheduler_output)
+        old_scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        new_scheduled_encoder_inputs = {}
+
+        # Erase cached inputs to execute mm encoder without repeated cache inputs
+        going_to_be_executed = set()
+        for req_id, mm_input_ids in old_scheduled_encoder_inputs.items():
+            mm_hashes = self.requests[req_id].mm_hashes
+            uncached_inputs = []
+            for input_id in mm_input_ids:
+                mm_hash = mm_hashes[input_id]
+                if ((not mm_hash in self.encoder_cache)
+                        and (mm_hash not in going_to_be_executed)):
+                    uncached_inputs.append(input_id)
+                    going_to_be_executed.add(mm_hash)
+            new_scheduled_encoder_inputs[req_id] = uncached_inputs
+
+        scheduler_output.scheduled_encoder_inputs = new_scheduled_encoder_inputs
+
         self._execute_mm_encoder(scheduler_output)
-        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+
+        scheduler_output.scheduled_encoder_inputs = old_scheduled_encoder_inputs
+        del new_scheduled_encoder_inputs
+
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
 
         for req_id, mm_input_ids in scheduled_encoder_inputs.items():
             mm_hashes = self.requests[req_id].mm_hashes
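The de-duplication pass above can be exercised in isolation. The toy example below uses plain dicts as stand-ins for the runner's `self.requests` and `self.encoder_cache` (the data values are invented for illustration); it shows how inputs whose hash is already cached, or already picked for execution earlier in the same batch, are filtered out while the full schedule is kept for the later per-input loop.

# Toy reproduction of the de-duplication pass: keep only inputs that are
# neither already in the encoder cache nor already selected for execution
# in this batch (same mm_hash appearing under another request/input id).
encoder_cache = {"hash_a"}                      # already computed earlier
request_hashes = {
    "req-0": ["hash_a", "hash_b"],
    "req-1": ["hash_b", "hash_c"],
}
scheduled = {"req-0": [0, 1], "req-1": [0, 1]}  # everything the scheduler sent

going_to_be_executed = set()
filtered = {}
for req_id, input_ids in scheduled.items():
    uncached = []
    for input_id in input_ids:
        mm_hash = request_hashes[req_id][input_id]
        if mm_hash not in encoder_cache and mm_hash not in going_to_be_executed:
            uncached.append(input_id)
            going_to_be_executed.add(mm_hash)
    filtered[req_id] = uncached

# hash_a is already cached, and hash_b is encoded only once (for req-0).
assert filtered == {"req-0": [1], "req-1": [1]}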
