@@ -1,15 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import asyncio
-from typing import List, Mapping, Optional, Union
+from typing import List, Mapping, Optional, Tuple, Union, cast
 
 from typing_extensions import assert_never
 
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
+                                    MultiModalInputs)
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
@@ -495,6 +496,51 @@ def _build_enc_dec_llm_inputs(
             decoder=decoder_inputs,
         )
 
+    def _separate_enc_dec_inputs_from_mm_processor_outputs(
+        self,
+        inputs: SingletonInputs,
+        decoder_inputs_to_override: Optional[SingletonInputs] = None,
+    ) -> Tuple[SingletonInputs, SingletonInputs]:
+        """
+        For encoder/decoder models only:
+        Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
+        """
+        encoder_inputs: SingletonInputs
+        decoder_inputs: SingletonInputs
+        if inputs["type"] == "multimodal":
+            # Multimodal data inputs
+            assert ("encoder_prompt" in inputs
+                    and "encoder_prompt_token_ids" in inputs)
+            inputs = cast(MultiModalEncDecInputs, inputs)
+            encoder_inputs = token_inputs(
+                prompt=inputs["encoder_prompt"],
+                prompt_token_ids=inputs["encoder_prompt_token_ids"],
+            )
+            if decoder_inputs_to_override is not None:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=decoder_inputs_to_override.get("prompt", ""),
+                    prompt_token_ids=decoder_inputs_to_override[
+                        "prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+            else:
+                decoder_inputs = MultiModalInputs(
+                    type="multimodal",
+                    prompt=inputs["prompt"],
+                    prompt_token_ids=inputs["prompt_token_ids"],
+                    mm_kwargs=inputs["mm_kwargs"],
+                    mm_placeholders=inputs["mm_placeholders"],
+                )
+        elif inputs["type"] == "token":
+            # Text-only inputs
+            encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
+            decoder_inputs = decoder_inputs_to_override or inputs
+        else:
+            assert_never(inputs)  # type: ignore[arg-type]
+        return encoder_inputs, decoder_inputs
+
     def _process_encoder_decoder_prompt(
         self,
         prompt: PromptType,
@@ -539,21 +585,35 @@ def _process_encoder_decoder_prompt(
                 prompt["encoder_prompt"],
                 request_id=request_id,
             )
-
             if (decoder_input := prompt["decoder_prompt"]) is None:
                 decoder_inputs = None
             else:
                 decoder_inputs = self._prompt_to_llm_inputs(
                     decoder_input,
                     request_id=request_id,
                 )
+            # For multimodal model, override decoder prompt from processor
+            # with explicit decoder prompt.
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = self._prompt_to_llm_inputs(
+            inputs = self._prompt_to_llm_inputs(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
-            decoder_inputs = None
+                decoder_inputs = None
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
@@ -583,13 +643,29 @@ async def _process_encoder_decoder_prompt_async(
 
             encoder_inputs, decoder_inputs = await asyncio.gather(
                 encoder_task, decoder_task)
+
+            # For multimodal model, override decoder prompt from processor
+            # with explicit decoder prompt.
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        encoder_inputs, decoder_inputs))
         else:
-            encoder_inputs = await self._prompt_to_llm_inputs_async(
+            inputs = await self._prompt_to_llm_inputs_async(
                 prompt,
                 request_id=request_id,
             )
+            if self.model_config.is_multimodal_model and (
+                    self._can_process_multimodal()):
+                # Encoder-Decoder Multimodal model
+                encoder_inputs, decoder_inputs = (
+                    self._separate_enc_dec_inputs_from_mm_processor_outputs(
+                        inputs))
+            else:
+                encoder_inputs = inputs
 
-            decoder_inputs = None
+                decoder_inputs = None
 
         return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
 
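
To make the branching above concrete, here is a minimal, standalone sketch of what _separate_enc_dec_inputs_from_mm_processor_outputs does. It is an illustration only: plain dicts stand in for vLLM's SingletonInputs/MultiModalInputs TypedDicts, token_inputs is a simplified stand-in for the real factory, and all example values are invented.

from typing import Any, Dict, Optional, Tuple

def token_inputs(prompt: str, prompt_token_ids: list) -> Dict[str, Any]:
    # Simplified stand-in for vLLM's token_inputs() factory.
    return {"type": "token", "prompt": prompt,
            "prompt_token_ids": prompt_token_ids}

def separate_enc_dec_inputs(
    inputs: Dict[str, Any],
    decoder_inputs_to_override: Optional[Dict[str, Any]] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    if inputs["type"] == "multimodal":
        # The multimodal processor emitted a dedicated encoder prompt;
        # lift it out as plain token inputs for the encoder.
        encoder = token_inputs(inputs["encoder_prompt"],
                               inputs["encoder_prompt_token_ids"])
        # An explicit decoder prompt overrides the processor's decoder
        # text, but mm_kwargs/mm_placeholders always come from the
        # processor output.
        src = decoder_inputs_to_override or inputs
        decoder = {"type": "multimodal",
                   "prompt": src.get("prompt", ""),
                   "prompt_token_ids": src["prompt_token_ids"],
                   "mm_kwargs": inputs["mm_kwargs"],
                   "mm_placeholders": inputs["mm_placeholders"]}
    else:
        # Text-only inputs: the encoder side is left empty.
        encoder = token_inputs("", [])
        decoder = decoder_inputs_to_override or inputs
    return encoder, decoder

# Invented processor output for an encoder/decoder multimodal model.
mm_out = {"type": "multimodal",
          "prompt": "<|image|> describe the image",
          "prompt_token_ids": [1, 32000, 5],
          "encoder_prompt": "<|image|>",
          "encoder_prompt_token_ids": [32000],
          "mm_kwargs": {"pixel_values": "..."},
          "mm_placeholders": {"image": [{"offset": 0, "length": 1}]}}
enc, dec = separate_enc_dec_inputs(mm_out)
assert enc["prompt_token_ids"] == [32000]        # encoder side
assert dec["prompt_token_ids"] == [1, 32000, 5]  # decoder keeps mm data

The same split serves both call sites in the diff: the explicit encoder/decoder path passes the processor output together with the user-supplied decoder prompt, while the singleton-prompt path passes the processor output alone and lets the processor's own decoder prompt stand.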