
Commit 48f07a7

WoosukKwon and ywang96 authored and committed
[V1] Support VLMs with fine-grained scheduling (vllm-project#9871)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>
1 parent e770bfa commit 48f07a7

File tree: 12 files changed, +542 -96 lines changed


vllm/model_executor/models/gpt2.py

Lines changed: 9 additions & 2 deletions

@@ -216,9 +216,11 @@ def forward(
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors],
+        inputs_embeds: Optional[torch.Tensor],
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if get_pp_group().is_first_rank:
-            inputs_embeds = self.wte(input_ids)
+            if inputs_embeds is None:
+                inputs_embeds = self.wte(input_ids)
             position_embeds = self.wpe(position_ids)
             hidden_states = inputs_embeds + position_embeds
         else:
@@ -263,16 +265,21 @@ def __init__(
         self.make_empty_intermediate_tensors = (
             self.transformer.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.transformer.wte(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                         attn_metadata, intermediate_tensors)
+                                         attn_metadata, intermediate_tensors,
+                                         inputs_embeds)
         return hidden_states
 
     def compute_logits(
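
The same change recurs in llama.py, opt.py, and qwen2.py below: the top-level model gains get_input_embeddings() and an optional inputs_embeds argument on forward(), so a caller can precompute (and modify) the embeddings and skip the token lookup inside the model. A self-contained toy sketch of that calling convention; the ToyLM class, sizes, and token ids are invented for illustration and are not part of this commit:

from typing import Optional

import torch
import torch.nn as nn


class ToyLM(nn.Module):
    # Mirrors the branching added to GPT2Model.forward above: only look up
    # token embeddings when the caller did not already supply them.

    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.wte(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        return self.proj(inputs_embeds)


model = ToyLM()
token_ids = torch.tensor([1, 2, 3])

# Text path: embeddings are looked up inside forward().
out_a = model(token_ids)

# Embedding path: compute embeddings up front (e.g. to merge in vision
# features), then pass them in; input_ids is not needed for the lookup.
embeds = model.get_input_embeddings(token_ids)
out_b = model(input_ids=None, inputs_embeds=embeds)
assert torch.allclose(out_a, out_b)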

vllm/model_executor/models/llama.py

Lines changed: 6 additions & 1 deletion

@@ -538,16 +538,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             normalize=False,
             softmax=False)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
         return model_output
 
     def compute_logits(

vllm/model_executor/models/llava.py

Lines changed: 28 additions & 18 deletions

@@ -17,6 +17,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of
@@ -448,13 +449,33 @@ def _process_image_input(self,
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)
 
+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.config.image_token_index)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """Run forward pass for LLaVA-1.5.
@@ -494,24 +515,13 @@ def forward(
         """
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.config.image_token_index)
-            else:
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
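
With process_mm_inputs() and get_input_embeddings() factored out of forward(), the vision encoder can run as its own step and its outputs are merged into the text embeddings at the image placeholder positions before the language model runs. A self-contained toy sketch of that merge; the placeholder id, sizes, and random tensors are invented for illustration and simplify what merge_multimodal_embeddings does:

import torch

IMAGE_TOKEN_ID = 32000      # hypothetical placeholder token id
hidden_size = 8

# A prompt with three text tokens followed by two image placeholder slots.
input_ids = torch.tensor([5, 17, 9, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID])
text_embeds = torch.randn(input_ids.shape[0], hidden_size)
vision_embeds = torch.randn(2, hidden_size)   # one row per placeholder slot

# Merge: overwrite the placeholder rows with the vision encoder outputs.
inputs_embeds = text_embeds.clone()
inputs_embeds[input_ids == IMAGE_TOKEN_ID] = vision_embeds

# The merged tensor is what forward() now passes to the language model via
# the new inputs_embeds argument, with input_ids set to None.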

vllm/model_executor/models/opt.py

Lines changed: 6 additions & 1 deletion

@@ -360,16 +360,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(

vllm/model_executor/models/phi3v.py

Lines changed: 40 additions & 23 deletions

@@ -39,6 +39,7 @@
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of
@@ -500,23 +501,29 @@ def input_processor_for_phi3v(ctx: InputContext,
 
     # TODO: Move this to utils or integrate with clip.
     new_token_ids: List[int] = []
+    placeholder_ranges: List[PlaceholderRange] = []
     placeholder_idx = 0
     while merged_token_ids:
         token_id = merged_token_ids.pop(0)
         if token_id == _IMAGE_TOKEN_ID:
-            new_token_ids.extend(
-                repeat_and_pad_token(
-                    _IMAGE_TOKEN_ID,
-                    repeat_count=image_feature_size[placeholder_idx],
-                ))
+            replacement_ids = repeat_and_pad_token(
+                _IMAGE_TOKEN_ID,
+                repeat_count=image_feature_size[placeholder_idx],
+            )
+            placeholder_ranges.append({
+                "offset": len(new_token_ids),
+                "length": len(replacement_ids)
+            })
+            new_token_ids.extend(replacement_ids)
             placeholder_idx += 1
         else:
             new_token_ids.append(token_id)
 
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": placeholder_ranges})
 
 
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
@@ -669,32 +676,42 @@ def _process_image_input(
 
         return image_embeds
 
+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embed_tokens(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
+        return inputs_embeds
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs: object):
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.embed_tokens(input_ids)
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.image_token_id)
-            else:
-                inputs_embeds = self.language_model.model.embed_tokens(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
                                                   positions,
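
The new multi_modal_placeholders field records where each image's placeholder run sits in the expanded token sequence, which is what lets the V1 scheduler account for encoder inputs image by image. A toy walk-through of the bookkeeping above; the token ids and feature sizes are invented, and the list multiplication stands in for repeat_and_pad_token:

from typing import Dict, List

IMAGE_TOKEN_ID = -1                  # made-up stand-in for _IMAGE_TOKEN_ID
merged_token_ids = [101, IMAGE_TOKEN_ID, 102, IMAGE_TOKEN_ID, 103]
image_feature_size = [4, 2]          # placeholder tokens per image

new_token_ids: List[int] = []
placeholder_ranges: List[Dict[str, int]] = []
placeholder_idx = 0
for token_id in merged_token_ids:
    if token_id == IMAGE_TOKEN_ID:
        replacement_ids = [IMAGE_TOKEN_ID] * image_feature_size[placeholder_idx]
        placeholder_ranges.append({
            "offset": len(new_token_ids),
            "length": len(replacement_ids),
        })
        new_token_ids.extend(replacement_ids)
        placeholder_idx += 1
    else:
        new_token_ids.append(token_id)

assert placeholder_ranges == [{"offset": 1, "length": 4},
                              {"offset": 6, "length": 2}]
assert new_token_ids == [101, -1, -1, -1, -1, 102, -1, -1, 103]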

vllm/model_executor/models/qwen2.py

Lines changed: 6 additions & 1 deletion

@@ -441,16 +441,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states
 
     def compute_logits(
vllm/v1/core/encoder_cache_manager.py (new file)

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+from typing import Dict, List, Set, Tuple
+
+from vllm.v1.request import Request
+
+
+class EncoderCacheManager:
+
+    def __init__(self, cache_size: int):
+        self.cache_size = cache_size
+        self.num_free_slots = cache_size
+        # req_id -> cached input ids
+        self.cached: Dict[str, Set[int]] = {}
+        # List of [req_id, input_id]
+        self.freed: List[Tuple[str, int]] = []
+
+    def has_cache(self, request: Request, input_id: int) -> bool:
+        req_id = request.request_id
+        return req_id in self.cached and input_id in self.cached[req_id]
+
+    def can_allocate(self, request: Request, input_id: int) -> bool:
+        num_tokens = request.get_num_encoder_tokens(input_id)
+        return num_tokens <= self.num_free_slots
+
+    def allocate(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            self.cached[req_id] = set()
+        self.cached[req_id].add(input_id)
+        self.num_free_slots -= request.get_num_encoder_tokens(input_id)
+
+    def get_cached_input_ids(self, request: Request) -> Set[int]:
+        return self.cached.get(request.request_id, set())
+
+    def free(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            return
+
+        self.cached[req_id].discard(input_id)
+        if len(self.cached[req_id]) == 0:
+            del self.cached[req_id]
+        self.num_free_slots += request.get_num_encoder_tokens(input_id)
+        self.freed.append((req_id, input_id))
+
+    def get_freed_ids(self) -> List[Tuple[str, int]]:
+        freed = self.freed
+        self.freed = []
+        return freed
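
A usage sketch for the EncoderCacheManager above. The import assumes the module path given in the file header, and FakeRequest is a stand-in that exposes only the two members the manager touches; the real vllm.v1.request.Request carries much more state:

from dataclasses import dataclass, field
from typing import Dict

from vllm.v1.core.encoder_cache_manager import EncoderCacheManager


@dataclass
class FakeRequest:
    request_id: str
    # input_id -> number of encoder tokens that multimodal input occupies
    encoder_tokens: Dict[int, int] = field(default_factory=dict)

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self.encoder_tokens[input_id]


manager = EncoderCacheManager(cache_size=2048)
req = FakeRequest("req-0", {0: 576, 1: 576})

# Scheduler side: run the vision encoder for input 0 only if its output fits.
if not manager.has_cache(req, 0) and manager.can_allocate(req, 0):
    manager.allocate(req, 0)        # reserves 576 of the 2048 slots

# Once the decoder has consumed every placeholder position of input 0, the
# cached encoder output can be freed and its slots reclaimed.
manager.free(req, 0)
assert manager.get_freed_ids() == [("req-0", 0)]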
