[V1][VLM] Enable proper chunked prefill for multimodal models #9950

Closed
wants to merge 17 commits
14 changes: 8 additions & 6 deletions vllm/config.py
@@ -1129,18 +1129,20 @@ def __init__(self,
# for higher throughput.
max_num_batched_tokens = max(max_model_len, 2048)

if is_multimodal_model:
# The value needs to be at least the number of multimodal
# tokens if chunked prefill is not enabled.
max_num_batched_tokens = max(
max_num_batched_tokens,
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

if task == "embedding":
# For embedding, choose specific value for higher throughput
max_num_batched_tokens = max(
max_num_batched_tokens,
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS,
)
if is_multimodal_model:
# The value needs to be at least the number of multimodal tokens
max_num_batched_tokens = max(
max_num_batched_tokens,
_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
)

self.max_num_batched_tokens = max_num_batched_tokens

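Note that because both floors above are applied with max(), the reordering only regroups the defaults; a multimodal embedding model ends up with the larger of the two caps either way. A rough standalone sketch of the resulting default computation (the constant values below are illustrative assumptions, not taken from this diff):

    # Sketch of the default max_num_batched_tokens resolution after this change.
    # The two constants are placeholders; the real values live in vllm/config.py.
    _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768   # assumed for illustration
    _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120   # assumed for illustration

    def default_max_num_batched_tokens(max_model_len: int,
                                       is_multimodal_model: bool,
                                       task: str) -> int:
        budget = max(max_model_len, 2048)
        if task == "embedding":
            # Embedding models get a larger default for throughput.
            budget = max(budget, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS)
        if is_multimodal_model:
            # Ensures the batch can hold at least the multimodal token count.
            budget = max(budget, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS)
        return budget
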
51 changes: 35 additions & 16 deletions vllm/model_executor/models/blip2.py
@@ -472,9 +472,13 @@ def input_processor_for_blip2(ctx: InputContext, inputs: DecoderOnlyInputs):
if new_prompt is not None:
new_prompt = BLIP2_IMAGE_TOKEN * image_feature_size + new_prompt

ranges = consecutive_placeholder_ranges(num_items=1,
item_size=image_feature_size)

return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": ranges})


@MULTIMODAL_REGISTRY.register_image_input_mapper()
@@ -609,13 +613,34 @@ def _process_image_input(self,

return self.language_projection(query_output)

def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self, input_ids: torch.Tensor,
vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=vision_embeddings,
placeholder_token_id=BLIP2_IMAGE_TOKEN_ID)

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2.
@@ -648,24 +673,18 @@ def forward(
See also:
:class:`Blip2ImageInputs`
"""

if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
BLIP2_IMAGE_TOKEN_ID)

input_ids = None
else:
inputs_embeds = None
# TODO (ywang96): This is currently needed since embedding generation
# takes place in the model forward pass. Clean this up after V0 is
# fully deprecated.
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
inputs_embeds = self.get_input_embeddings(
input_ids=input_ids, vision_embeddings=vision_embeddings)
input_ids = None

hidden_states = self.language_model.model(
input_ids,
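The net effect of the BLIP-2 changes is that multimodal embedding generation no longer has to happen inside forward(). A rough sketch of how a caller could drive the new helpers; the runner code itself is not part of this diff, so the function and variable names here are assumptions:

    def run_prefill_step(model, input_ids, positions, kv_caches,
                         attn_metadata, mm_kwargs):
        # Sketch only (not the actual V1 model runner): precompute image
        # embeddings once, then pass inputs_embeds so forward() skips its
        # in-place merge path.
        vision_embeddings = model.process_mm_inputs(**mm_kwargs)  # None when no image
        inputs_embeds = model.get_input_embeddings(input_ids, vision_embeddings)
        return model(
            input_ids=input_ids,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
        )
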
65 changes: 42 additions & 23 deletions vllm/model_executor/models/fuyu.py
@@ -196,9 +196,12 @@ def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token

ranges = consecutive_placeholder_ranges(num_items=len(image_list),
item_size=image_patches.size(1))
return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data)
multi_modal_data=new_multi_modal_data,
multi_modal_placeholders={"image": ranges})
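For context, consecutive_placeholder_ranges lays the placeholder ranges out back to back starting at offset 0, which matches how Fuyu puts all image tokens at the front of the expanded prompt. An illustrative result, based on the helper's name and its usage here rather than its source:

    # Illustrative only: two images, 5 placeholder tokens each.
    # The helper lives in vllm.multimodal.utils.
    ranges = consecutive_placeholder_ranges(num_items=2, item_size=5)
    # Expected layout:
    #   [PlaceholderRange(offset=0, length=5),
    #    PlaceholderRange(offset=5, length=5)]
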


def input_mapper_for_fuyu(ctx: InputContext, data: object):
@@ -298,40 +301,56 @@ def _process_image_input(
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings

def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self, input_ids: torch.Tensor,
vision_embeddings: Optional[torch.Tensor]) -> torch.Tensor:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=vision_embeddings,
placeholder_token_id=self.image_token_id)

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
):
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)

if image_input is not None:
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(
input_ids)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.image_token_id)

else:
inputs_embeds = None

hidden_states = self.language_model(
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,

# TODO (ywang96): This is currently needed since embedding generation
# takes place in the model forward pass. Clean this up after V0 is
# fully deprecated.
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
inputs_embeds = self.get_input_embeddings(
input_ids=input_ids, vision_embeddings=vision_embeddings)
input_ids = None

hidden_states = self.language_model.model(
input_ids,
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
inputs_embeds=inputs_embeds)

return hidden_states

def compute_logits(
92 changes: 70 additions & 22 deletions vllm/model_executor/models/internvl.py
@@ -26,6 +26,7 @@
InternVisionPatchModel)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -321,10 +322,26 @@ def input_processor(
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
num_patches)
new_prompt_token_ids = tokenizer.encode(new_prompt)

return token_inputs(prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data)
img_context_token_id = tokenizer.encode(self.img_context_token)[1]

# Get precise tracking of placeholder positions
token_idx = image_idx = 0
placeholder_ranges = []
while token_idx < len(new_prompt_token_ids):
if new_prompt_token_ids[token_idx] == img_context_token_id:
curr_image_feature_size = image_feature_sizes[image_idx]
placeholder_ranges.append(
PlaceholderRange(offset=token_idx,
length=curr_image_feature_size))
image_idx += 1
token_idx += curr_image_feature_size
else:
token_idx += 1
return token_inputs(
prompt=prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
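A small worked example of the scan above, with a made-up token sequence (IMG marks img_context_token_id; all values are hypothetical):

    # new_prompt_token_ids: [t, t, t, IMG, IMG, IMG, IMG, t, IMG, IMG, t]
    # image_feature_sizes:  [4, 2]
    # The loop jumps over each placeholder block, so the result is
    #   placeholder_ranges == [PlaceholderRange(offset=3, length=4),
    #                          PlaceholderRange(offset=8, length=2)]
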

def input_mapper(
self,
@@ -576,12 +593,11 @@ def _parse_and_validate_image_input(
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
# We need to flatten (B, N, P) to (B*N*P),
# so we call flatten_bn twice.

# Pass as a list of pixel value tensors
return InternVLImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(
flatten_bn(flatten_bn(pixel_values), concat=True)),
data=pixel_values,
)

raise AssertionError("This line should be unreachable.")
@@ -594,7 +610,15 @@ def _process_image_input(
return image_input["data"]

assert self.vision_model is not None
image_embeds = self.extract_feature(image_input["data"])

# Output as a list of image embeddings
image_embeds = []

for pixel_values in image_input["data"]:
vision_embeddings = self.extract_feature(flatten_bn(pixel_values))
hidden_size = vision_embeddings.shape[2]
vision_embeddings = vision_embeddings.reshape(1, -1, hidden_size)
image_embeds.append(vision_embeddings)

return image_embeds
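Shape-wise, each entry of the pixel-value list is now handled independently; roughly the following, where the dimensions are illustrative and the per-patch token count depends on the vision config:

    # For one image split into N patches (illustrative shapes):
    #   flatten_bn(pixel_values) -> (N, 3, H, W)
    #   extract_feature(...)     -> (N, T, hidden)      # T visual tokens per patch
    #   reshape(1, -1, hidden)   -> (1, N * T, hidden)  # one flat sequence per image
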

@@ -606,33 +630,57 @@ def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
visual_token_mask = None
return visual_token_mask

def process_mm_inputs(self, **kwargs: object) -> Optional[torch.Tensor]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings

def get_input_embeddings(
self,
input_ids: torch.Tensor,
vision_embeddings: Optional[torch.Tensor] = None,
) -> torch.Tensor:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
if vision_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
multimodal_embeddings=vision_embeddings,
placeholder_token_id=self.img_context_token_id)

return inputs_embeds

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs: object,
) -> Union[SamplerOutput, IntermediateTensors]:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
visual_token_mask = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.img_context_token_id)

# TODO (ywang96): This is currently needed since embedding generation
# takes place in the model forward pass. Clean this up after V0 is
# fully deprecated.
elif inputs_embeds is None:
vision_embeddings = self.process_mm_inputs(**kwargs)
if vision_embeddings is not None:
visual_token_mask = self._get_visual_token_mask(input_ids)
input_ids = None
else:
inputs_embeds = None
visual_token_mask = None
inputs_embeds = self.get_input_embeddings(
input_ids=input_ids, vision_embeddings=vision_embeddings)
input_ids = None

else:
visual_token_mask = self._get_visual_token_mask(input_ids)

forward_kwargs = {
"input_ids": input_ids,
8 changes: 7 additions & 1 deletion vllm/model_executor/models/persimmon.py
@@ -235,6 +235,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

def forward(
self,
input_ids: torch.Tensor,
@@ -267,7 +270,7 @@ def forward(

class PersimmonForCausalLM(nn.Module, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
@@ -282,6 +285,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)

def forward(
self,
input_ids: torch.Tensor,
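The persimmon.py additions exist because Fuyu's language model is Persimmon, so the new Fuyu.get_input_embeddings() needs an embedding hook to delegate to. A simplified sketch of the delegation chain introduced by this diff:

    # FuyuForCausalLM.get_input_embeddings(input_ids, vision_embeddings)
    #   -> self.language_model.get_input_embeddings(input_ids)   # PersimmonForCausalLM
    #     -> self.model.get_input_embeddings(input_ids)          # PersimmonModel
    #       -> self.embed_tokens(input_ids)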