Commit 04edd1c
DCO
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 112fa0b commit 04edd1c

10 files changed (+527, -93 lines)
vllm/model_executor/models/llama.py

Lines changed: 6 additions & 1 deletion
@@ -538,16 +538,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                          normalize=False,
                          softmax=False)

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         model_output = self.model(input_ids, positions, kv_caches,
-                                  attn_metadata, intermediate_tensors)
+                                  attn_metadata, intermediate_tensors,
+                                  inputs_embeds)
         return model_output

     def compute_logits(
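The change above (mirrored in opt.py below) exposes the inner model's embedding lookup and lets callers hand the decoder precomputed embeddings. A minimal sketch of the calling pattern this enables; `model`, `positions`, `kv_caches`, and `attn_metadata` are stand-ins for objects the model runner already holds, not code from this commit:

# Sketch only: illustrates the new calling convention, not runner code from this commit.
# Look up token embeddings outside forward() ...
inputs_embeds = model.get_input_embeddings(input_ids)   # [num_tokens, hidden_size]
# ... optionally adjust them (e.g. splice in multimodal features), then drive the
# decoder with embeddings. The inner model is expected to use `inputs_embeds`
# when it is provided instead of re-embedding `input_ids`.
hidden_states = model(
    input_ids=input_ids,
    positions=positions,
    kv_caches=kv_caches,
    attn_metadata=attn_metadata,
    intermediate_tensors=None,
    inputs_embeds=inputs_embeds,
)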
vllm/model_executor/models/llava.py

Lines changed: 28 additions & 18 deletions
@@ -17,6 +17,7 @@
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_list_of

@@ -448,13 +449,33 @@ def _process_image_input(self,
         image_features = self._process_image_pixels(image_input)
         return self.multi_modal_projector(image_features)

+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.config.image_token_index)
+        return inputs_embeds
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         """Run forward pass for LLaVA-1.5.
@@ -494,24 +515,13 @@ def forward(
         """
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.config.image_token_index)
-            else:
-                inputs_embeds = self.language_model.model.get_input_embeddings(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
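The rewritten forward splits multimodal handling into two steps that can also be driven from outside the model: process_mm_inputs runs the vision tower and projector, and get_input_embeddings merges the result into the token embeddings. A hedged sketch of the flow (the `model` variable and `pixel_values` kwarg are illustrative, not taken from this diff):

# Illustrative only; mirrors the branch added to forward() above.
# `pixel_values` is the usual LLaVA image kwarg and is not shown in this diff.
vision_embeddings = model.process_mm_inputs(pixel_values=pixel_values)  # None if no image
inputs_embeds = model.get_input_embeddings(input_ids, vision_embeddings)
# The decoder is then always driven by `inputs_embeds` (input_ids is dropped),
# so the traced graph stays identical whether or not images are present.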

vllm/model_executor/models/opt.py

Lines changed: 6 additions & 1 deletion
@@ -360,16 +360,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
     def forward(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
         kv_caches: List[torch.Tensor],
         attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, intermediate_tensors)
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
         return hidden_states

     def compute_logits(

vllm/model_executor/models/phi3v.py

Lines changed: 40 additions & 23 deletions
@@ -39,6 +39,7 @@
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.base import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
 from vllm.sequence import IntermediateTensors, PoolerOutput
 from vllm.utils import is_list_of
@@ -500,23 +501,29 @@ def input_processor_for_phi3v(ctx: InputContext,

     # TODO: Move this to utils or integrate with clip.
     new_token_ids: List[int] = []
+    placeholder_ranges: List[PlaceholderRange] = []
     placeholder_idx = 0
     while merged_token_ids:
         token_id = merged_token_ids.pop(0)
         if token_id == _IMAGE_TOKEN_ID:
-            new_token_ids.extend(
-                repeat_and_pad_token(
-                    _IMAGE_TOKEN_ID,
-                    repeat_count=image_feature_size[placeholder_idx],
-                ))
+            replacement_ids = repeat_and_pad_token(
+                _IMAGE_TOKEN_ID,
+                repeat_count=image_feature_size[placeholder_idx],
+            )
+            placeholder_ranges.append({
+                "offset": len(new_token_ids),
+                "length": len(replacement_ids)
+            })
+            new_token_ids.extend(replacement_ids)
             placeholder_idx += 1
         else:
             new_token_ids.append(token_id)

     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
                         prompt=new_prompt,
-                        multi_modal_data=multi_modal_data)
+                        multi_modal_data=multi_modal_data,
+                        multi_modal_placeholders={"image": placeholder_ranges})


 @MULTIMODAL_REGISTRY.register_image_input_mapper()
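Each entry in placeholder_ranges records where a run of repeated image tokens starts in the expanded prompt and how long it is. A toy, self-contained rerun of the bookkeeping above (token ids and feature sizes are invented, and repeat_and_pad_token is simplified to plain repetition):

# Toy illustration of the placeholder bookkeeping; values are invented.
IMG = -1  # stand-in for _IMAGE_TOKEN_ID
merged_token_ids = [1, IMG, 5, 6, IMG, 7]   # prompt with two image placeholders
image_feature_size = [3, 2]                 # feature tokens per image

new_token_ids, placeholder_ranges, placeholder_idx = [], [], 0
while merged_token_ids:
    token_id = merged_token_ids.pop(0)
    if token_id == IMG:
        replacement_ids = [IMG] * image_feature_size[placeholder_idx]  # simplified repeat_and_pad_token
        placeholder_ranges.append({"offset": len(new_token_ids),
                                   "length": len(replacement_ids)})
        new_token_ids.extend(replacement_ids)
        placeholder_idx += 1
    else:
        new_token_ids.append(token_id)

# new_token_ids      == [1, -1, -1, -1, 5, 6, -1, -1, 7]
# placeholder_ranges == [{"offset": 1, "length": 3}, {"offset": 6, "length": 2}]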
@@ -669,32 +676,42 @@ def _process_image_input(

         return image_embeds

+    def process_mm_inputs(self, **kwargs):
+        image_input = self._parse_and_validate_image_input(**kwargs)
+        if image_input is None:
+            return None
+        vision_embeddings = self._process_image_input(image_input)
+        return vision_embeddings
+
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        vision_embeddings: Optional[NestedTensors] = None,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embed_tokens(input_ids)
+        if vision_embeddings is not None:
+            inputs_embeds = merge_multimodal_embeddings(
+                input_ids, inputs_embeds, vision_embeddings,
+                self.image_token_id)
+        return inputs_embeds
+
     def forward(self,
                 input_ids: torch.Tensor,
                 positions: torch.Tensor,
                 kv_caches: List[torch.Tensor],
                 attn_metadata: AttentionMetadata,
                 intermediate_tensors: Optional[IntermediateTensors] = None,
+                inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs: object):
         if intermediate_tensors is not None:
             inputs_embeds = None
-        else:
-            image_input = self._parse_and_validate_image_input(**kwargs)
-
-            if image_input is not None:
-                vision_embeddings = self._process_image_input(image_input)
-                inputs_embeds = self.embed_tokens(input_ids)
-                inputs_embeds = merge_multimodal_embeddings(
-                    input_ids, inputs_embeds, vision_embeddings,
-                    self.image_token_id)
-            else:
-                inputs_embeds = self.language_model.model.embed_tokens(
-                    input_ids)
-
-            # always pass the input via `inputs_embeds`
-            # to make sure the computation graph is consistent
-            # for `torch.compile` integration
-            input_ids = None
+        elif inputs_embeds is None:
+            vision_embeddings = self.process_mm_inputs(**kwargs)
+            # always pass the input via `inputs_embeds`
+            # to make sure the computation graph is consistent
+            inputs_embeds = self.get_input_embeddings(input_ids,
+                                                      vision_embeddings)
+            input_ids = None

         hidden_states = self.language_model.model(input_ids,
                                                   positions,
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+from typing import Dict, List, Set, Tuple
+
+from vllm.v1.request import Request
+
+
+class EncoderCacheManager:
+
+    def __init__(self, cache_size: int):
+        self.cache_size = cache_size
+        self.num_free_slots = cache_size
+        # req_id -> cached input ids
+        self.cached: Dict[str, Set[int]] = {}
+        # List of [req_id, input_id]
+        self.freed: List[Tuple[str, int]] = []
+
+    def has_cache(self, request: Request, input_id: int) -> bool:
+        req_id = request.request_id
+        return req_id in self.cached and input_id in self.cached[req_id]
+
+    def can_allocate(self, request: Request, input_id: int) -> bool:
+        num_tokens = request.get_num_encoder_tokens(input_id)
+        return num_tokens <= self.num_free_slots
+
+    def allocate(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            self.cached[req_id] = set()
+        self.cached[req_id].add(input_id)
+        self.num_free_slots -= request.get_num_encoder_tokens(input_id)
+
+    def get_cached_input_ids(self, request: Request) -> Set[int]:
+        return self.cached.get(request.request_id, set())
+
+    def free(self, request: Request, input_id: int) -> None:
+        req_id = request.request_id
+        if req_id not in self.cached:
+            return
+
+        self.cached[req_id].discard(input_id)
+        if len(self.cached[req_id]) == 0:
+            del self.cached[req_id]
+        self.num_free_slots += request.get_num_encoder_tokens(input_id)
+        self.freed.append((req_id, input_id))
+
+    def get_freed_ids(self) -> List[Tuple[str, int]]:
+        freed = self.freed
+        self.freed = []
+        return freed
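EncoderCacheManager tracks, per request, which multimodal inputs already have encoder outputs resident, with a budget measured in encoder tokens. A hedged sketch of how a scheduler might drive it; the Request construction and the concrete ids and token counts are illustrative, not part of this commit:

# Illustrative scheduler-side usage; `request` is assumed to be a
# vllm.v1.request.Request whose get_num_encoder_tokens(input_id) reports
# how many encoder tokens that multimodal input occupies.
cache = EncoderCacheManager(cache_size=2048)

input_id = 0
if not cache.has_cache(request, input_id):
    if cache.can_allocate(request, input_id):
        # Reserve budget before running the encoder for this input.
        cache.allocate(request, input_id)
    else:
        pass  # defer this input until enough slots have been freed

# ... later, once the decoder no longer needs the cached encoder output ...
cache.free(request, input_id)        # returns the slots and queues (req_id, input_id)
for req_id, freed_input_id in cache.get_freed_ids():
    # e.g. notify workers that the corresponding encoder output can be dropped
    print(req_id, freed_input_id)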
