Commit 1df6a4f

feat: contiguous encoder cache
Signed-off-by: Kero Liang <kerorek@outlook.com>
1 parent 136a17f commit 1df6a4f

File tree

3 files changed, +60 −18 lines

vllm/multimodal/inputs.py

Lines changed: 35 additions & 4 deletions
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
 from itertools import accumulate
 from typing import (
@@ -167,11 +167,42 @@ class PlaceholderRange:
     between `offset` and `offset + length` to assign embeddings to.
     """

-    def get_num_embeds(self) -> int:
+    num_embeds: int = field(init=False)
+    """
+    The number of positions that actually result in an output from the encoder.
+    """
+
+    def __post_init__(self):
         if self.is_embed is None:
-            return self.length
+            object.__setattr__(self, "num_embeds", self.length)
+        else:
+            num_embeds = int(self.is_embed.sum().item())
+            object.__setattr__(self, "num_embeds", num_embeds)
+
+            # Remove leading & trailing False in `is_embed` for easier scheduling
+            if num_embeds > 0:
+                true_indices = torch.nonzero(self.is_embed, as_tuple=True)[0]
+                first_true_index = true_indices[0].item()
+                last_true_index = true_indices[-1].item()

-        return int(self.is_embed.sum().item())
+                start_trim_count = first_true_index
+                new_length = last_true_index - first_true_index + 1
+
+                object.__setattr__(self, "offset", self.offset + start_trim_count)
+                object.__setattr__(self, "length", new_length)
+
+                object.__setattr__(
+                    self,
+                    "is_embed",
+                    self.is_embed[first_true_index : last_true_index + 1],
+                )
+            else:
+                # Seems impossible?
+                object.__setattr__(self, "length", 0)
+                object.__setattr__(self, "is_embed", self.is_embed[0:0])
+
+    def get_num_embeds(self) -> int:
+        return self.num_embeds

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
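
A minimal, self-contained sketch of the trimming behaviour added in
`__post_init__` above, using hypothetical values (the mask, offset, and
length below are invented for illustration, not taken from vLLM):

    import torch

    # Hypothetical placeholder: 8 positions starting at token offset 10,
    # where only positions 2..5 actually produce encoder embeddings.
    offset, length = 10, 8
    is_embed = torch.tensor([False, False, True, True, True, True, False, False])

    num_embeds = int(is_embed.sum().item())  # 4

    # Drop the leading/trailing False positions, as the diff does.
    true_indices = torch.nonzero(is_embed, as_tuple=True)[0]
    first, last = true_indices[0].item(), true_indices[-1].item()

    offset += first                  # 10 -> 12
    length = last - first + 1        # 8 -> 4
    is_embed = is_embed[first : last + 1]

    print(offset, length, num_embeds)  # 12 4 4

After trimming, `offset` and `length` cover only the span that can produce
encoder output, which is what lets the cached encoder output be treated as one
contiguous block downstream.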

vllm/v1/request.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def get_finished_reason(self) -> FinishReason | None:

     def get_num_encoder_tokens(self, input_id: int) -> int:
         assert input_id < len(self.mm_features)
-        num_tokens = self.mm_features[input_id].mm_position.length
+        num_tokens = self.mm_features[input_id].mm_position.num_embeds
         return num_tokens

     def record_event(
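
A small illustration (with made-up values) of why the budget above switches
from `length` to `num_embeds`: even after trimming, a mask with interior gaps
keeps `length` larger than the number of vectors actually stored in the now
contiguous encoder cache:

    import torch

    is_embed = torch.tensor([True, False, True, True])  # already trimmed
    length = is_embed.numel()                  # 4 placeholder positions
    num_embeds = int(is_embed.sum().item())    # 3 cached encoder vectors

    # The cache holds only the True positions, so the request's encoder-token
    # count should be num_embeds rather than length.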

vllm/v1/worker/gpu_model_runner.py

Lines changed: 24 additions & 13 deletions
@@ -148,9 +148,7 @@
     MultiModalBudget,
     add_kv_sharing_layers_to_kv_cache_groups,
     bind_kv_cache,
-    gather_mm_placeholders,
     sanity_check_mm_encoder_outputs,
-    scatter_mm_placeholders,
 )

 if TYPE_CHECKING:
@@ -1774,10 +1772,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):

         # Cache the encoder outputs by mm_hash
         for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
-            self.encoder_cache[mm_hash] = scatter_mm_placeholders(
-                output,
-                is_embed=pos_info.is_embed,
-            )
+            self.encoder_cache[mm_hash] = output

     def _gather_mm_embeddings(
         self,
@@ -1828,20 +1823,36 @@ def _gather_mm_embeddings(
             encoder_output = self.encoder_cache.get(mm_hash, None)
             assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."

-            if (is_embed := pos_info.is_embed) is not None:
+            is_embed = pos_info.is_embed
+
+            # retrieve `encoder_output` slice based on `is_embed` mask
+            encoder_output_slice_start = start_idx
+            encoder_output_slice_end = end_idx
+            if is_embed is not None:
+                num_encoder_output_before_start = is_embed[:start_idx].sum().item()
+                num_encoder_output_selected = (
+                    is_embed[start_idx:end_idx].sum().item()
+                )
+
+                encoder_output_slice_start = num_encoder_output_before_start
+                encoder_output_slice_end = (
+                    num_encoder_output_before_start + num_encoder_output_selected
+                )
+
+            mm_embeds_item = encoder_output[
+                encoder_output_slice_start:encoder_output_slice_end
+            ]
+            mm_embeds_req.append(mm_embeds_item)
+
+            # append `is_mm_embed` mask
+            if is_embed is not None:
                 is_embed = is_embed[start_idx:end_idx]

             req_start_pos = req_start_idx + start_pos - num_computed_tokens
             is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                 True if is_embed is None else is_embed
             )

-            mm_embeds_item = gather_mm_placeholders(
-                encoder_output[start_idx:end_idx],
-                is_embed=is_embed,
-            )
-            mm_embeds_req.append(mm_embeds_item)
-
             if self.is_multimodal_pruning_enabled and self.uses_mrope:
                 assert req_state.mrope_positions is not None
                 should_sync_mrope_positions = True
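
A self-contained sketch of the new gather path, under hypothetical shapes
(`encoder_output`, `is_embed`, `start_idx`, and `end_idx` mirror the names in
the diff; the values are invented). Because the cache now stores only the real
embeddings, with no scattering, the slice bounds become prefix sums over
`is_embed` rather than raw token indices:

    import torch

    # Hypothetical contiguous cache entry: 4 embeddings of dim 8.
    encoder_output = torch.randn(4, 8)

    # Trimmed placeholder mask and the scheduled token window within it.
    is_embed = torch.tensor([True, False, True, True, False, True])
    start_idx, end_idx = 1, 5  # placeholder tokens [1, 5) are scheduled

    # Embeddings before the window, and embeddings inside it.
    n_before = int(is_embed[:start_idx].sum().item())           # 1
    n_selected = int(is_embed[start_idx:end_idx].sum().item())  # 2

    mm_embeds_item = encoder_output[n_before : n_before + n_selected]
    print(mm_embeds_item.shape)  # torch.Size([2, 8])

The window's `is_embed[start_idx:end_idx]` slice is still written into
`is_mm_embed` to mark which decoder positions receive these embeddings, so the
mapping from cache rows to token positions is preserved without the old
scatter/gather round trip.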
