From 415b3d431165f99f840cd044c21cf715dd391036 Mon Sep 17 00:00:00 2001
From: DarkLight1337 <tlleungac@connect.ust.hk>
Date: Wed, 1 Jan 2025 17:52:58 +0000
Subject: [PATCH 1/3] Support V1 multi-input in merged multi-modal processor

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
---
 vllm/multimodal/inputs.py     | 246 +++++++++++++++-------------------
 vllm/multimodal/processing.py |  45 +++----
 vllm/v1/engine/processor.py   |  19 ++-
 3 files changed, 142 insertions(+), 168 deletions(-)

diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index db489af7ac475..c41cf57e910b7 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -2,7 +2,8 @@
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
+from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
+                    final)
 
 import numpy as np
 import torch
@@ -11,7 +12,7 @@
 from transformers import BatchFeature
 from typing_extensions import NotRequired, TypeAlias
 
-from vllm.utils import JSONTree, is_list_of, json_map_leaves
+from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
 
 _T = TypeVar("_T")
 
@@ -160,11 +161,8 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
 
 
 @dataclass(frozen=True)
-class MultiModalFieldItem:
-    """
-    Contains metadata and data in :class:`MultiModalKwargs`
-    corresponding to a data item in :class:`MultiModalDataItems`.
-    """
+class MultiModalFieldElem:
+    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""
     field: "BaseMultiModalField"
     data: NestedTensors
 
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         raise NotImplementedError
 
-    def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
-        return MultiModalFieldItem(self, data)
+    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
+        return MultiModalFieldElem(self, data)
 
-    def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
-        """Merge multiple instances of :class:`MultiModalFieldItem` together."""
+    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
+        """Merge multiple instances of :class:`MultiModalFieldElem` together."""
         fields = [item.field for item in batch]
         if len(set(fields)) > 1:
             raise ValueError(f"Cannot merge different {fields=}")
 
         data = self._reduce_data([item.data for item in batch])
 
-        return self._build_item(data)
+        return self._build_elem(data)
 
 
 @dataclass(frozen=True)
 class MultiModalBatchedField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    directly indexing into the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by indexing into the first dimension of the underlying data.
     """
 
-    def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
-        return [self._build_item(item) for item in batch]
+    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
+        return [self._build_elem(item) for item in batch]
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape == first_shape for item in batch):
+            if all(elem.shape == first_shape for elem in batch):
                 return torch.stack(batch)
 
         return batch
@@ -222,24 +220,24 @@ def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
 @dataclass(frozen=True)
 class MultiModalFlatField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an item is obtained by
-    slicing along the first dimension of the underlying data.
+    A :class:`BaseMultiModalField` implementation where an element in the batch
+    is obtained by slicing along the first dimension of the underlying data.
     """
 
-    def build_items(
+    def build_elems(
         self,
         batch: NestedTensors,
         slices: Sequence[slice],
-    ) -> list[MultiModalFieldItem]:
-        return [self._build_item(batch[slice_]) for slice_ in slices]
+    ) -> list[MultiModalFieldElem]:
+        return [self._build_elem(batch[slice_]) for slice_ in slices]
 
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
             first_shape = batch[0].shape
-            if all(item.shape[1:] == first_shape[1:] for item in batch):
+            if all(elem.shape[1:] == first_shape[1:] for elem in batch):
                 return torch.concat(batch)
 
-        return [elem for item in batch for elem in item]
+        return [e for elem in batch for e in elem]
 
 
 class MultiModalFieldConfig:
@@ -267,115 +265,111 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self._field_cls = field_cls
-        self._modality = modality
-        self._field_config = field_config
+        self.field_cls = field_cls
+        self.modality = modality
+        self.field_config = field_config
 
-    def build_items(
+    def build_elems(
         self,
         key: str,
         batch: NestedTensors,
-    ) -> list[MultiModalFieldItem]:
-        field = self._field_cls(key=key, modality=self._modality)
-        return field.build_items(batch, **self._field_config)  # type: ignore
+    ) -> list[MultiModalFieldElem]:
+        field = self.field_cls(key=key, modality=self.modality)
+        return field.build_elems(batch, **self.field_config)  # type: ignore
 
 
-class MultiModalKwargs(UserDict[str, NestedTensors]):
+class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
+    """
+    A collection of :class:`MultiModalFieldElem`
+    corresponding to a data item in :class:`MultiModalDataItems`.
     """
-    A dictionary that represents the keyword arguments to
-    :meth:`~torch.nn.Module.forward`.
-
-    The metadata :code:`items_by_key` defines how to split batched keyword
-    arguments corresponding to each data item in :class:`MultiModalDataItems`:
 
-    - For a keyword argument, we can access the :code:`i` th item in the batch
-      via :code:`items_by_key[key][i]`.
-    - We can gather the keyword arguments belonging to a modality by finding
-      the keys with items that belong to that modality, then accessing
-      the :code:`i` th item in the batch for each such key.
+    @staticmethod
+    def from_elems(elems: list[MultiModalFieldElem]) -> "MultiModalKwargsItem":
+        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
 
-    Example:
+    @property
+    def modality(self) -> str:
+        modalities = {elem.field.modality for elem in self.data.values()}
+        assert len(modalities) == 1, f"Found different modalities={modalities}"
+        return next(iter(modalities))
 
-        .. code-block:: python
 
-            # All items belong to the "image" modality
-            items_by_key={
-                "pixel_values": [a, b, c, d],  # "image" modality
-                "image_grid_thw": [e, f, g, h],  # "image" modality
-                "pixel_values_video": [h, i, j],  # "video" modality
-                "video_grid_thw": [k, l, m],  # "video" modality
-            }
+# NOTE: UserDict is for V0 compatibility.
+# V1 should access individual items via `get_item`.
+class MultiModalKwargs(UserDict[str, NestedTensors]):
+    """
+    A dictionary that represents the keyword arguments to
+    :meth:`~torch.nn.Module.forward`.
 
-        - The keyword arguments belonging to the first image are
-          :code:`{"pixel_values": a, "image_grid_thw": e}`.
-        - The keyword arguments belonging to the second video are
-          :code:`{"pixel_values_video": i, "video_grid_thw": l}`.
+    The metadata :code:`items` enables us to obtain the keyword arguments
+    corresponding to each data item in :class:`MultiModalDataItems`, via
+    :meth:`get_num_items` and :meth:`get_item`.
     """
 
     @staticmethod
     def from_hf_inputs(
         hf_inputs: BatchFeature,
         config_by_key: Mapping[str, MultiModalFieldConfig],
-        *,
-        enable_sanity_checks: bool = False,
     ):
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
         # We assume that those fields are not used in vLLM
-        items_by_key = {
-            key: config.build_items(key, batch)
-            for key, config in config_by_key.items()
-            if (batch := hf_inputs.get(key)) is not None
-        }
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        elems_by_key = dict[str, list[MultiModalFieldElem]]()
+        keys_by_modality = defaultdict[str, set[str]](set)
+        for key, config in config_by_key.items():
+            batch = hf_inputs.get(key)
+            if batch is not None:
+                elems = config.build_elems(key, batch)
+                if len(elems) > 0:
+                    elems_by_key[key] = elems
+                    keys_by_modality[config.modality].add(key)
+
+        items = list[MultiModalKwargsItem]()
+        for modality, keys in keys_by_modality.items():
+            elems_in_modality = {k: elems_by_key[k] for k in keys}
+            batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
+
+            if len(set(batch_sizes.values())) > 1:
+                raise ValueError(
+                    f"Cannot merge different batch sizes for {modality=}! "
+                    f"Found: {batch_sizes=}")
+
+            batch_size = next(iter(batch_sizes.values()))
+            for item_idx in range(batch_size):
+                elems = [v[item_idx] for v in elems_in_modality.values()]
+                items.append(MultiModalKwargsItem.from_elems(elems))
+
+        return MultiModalKwargs.from_items(items)
 
     @staticmethod
-    def from_items_by_key(
-        items_by_key: Mapping[str, list[MultiModalFieldItem]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
+    def from_items(items: list[MultiModalKwargsItem]) -> "MultiModalKwargs":
+        """Construct a new :class:`MultiModalKwargs` from multiple items."""
+        elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
+        for item in items:
+            for key, elem in item.items():
+                elems_by_key[key].append(elem)
+
         data = {
-            key: items[0].field.reduce(items).data
-            for key, items in items_by_key.items() if len(items) > 0
+            key: elems[0].field.reduce(elems).data
+            for key, elems in elems_by_key.items() if len(elems) > 0
         }
 
-        return MultiModalKwargs(data,
-                                items_by_key=items_by_key,
-                                enable_sanity_checks=enable_sanity_checks)
+        return MultiModalKwargs(data, items=items)
 
     def __init__(
         self,
         data: Mapping[str, NestedTensors],
         *,
-        items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
-        enable_sanity_checks: bool = False,
+        items: Optional[Sequence[MultiModalKwargsItem]] = None,
     ) -> None:
         super().__init__(data)
 
-        # Shallow copy to avoid footgun in case a defaultdict is passed in
-        self._items_by_key = dict(items_by_key)
+        items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
+        self._items_by_modality = dict(items_by_modality)
 
-        keys_by_modality = defaultdict[str, set[str]](set)
-        for key, items in items_by_key.items():
-            for item in items:
-                keys_by_modality[item.field.modality].add(key)
-
-        self._keys_by_modality = dict(keys_by_modality)
-
-        if enable_sanity_checks:
-            for modality, keys in keys_by_modality.items():
-                items_in_modality = {k: items_by_key[k] for k in keys}
-                batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
-                batch_size = next(iter(batch_sizes.values()), 0)
-                assert all(bs == batch_size
-                           for bs in batch_sizes.values()), dict(
-                               modality=modality,
-                               batch_sizes=batch_sizes,
-                               items_by_key=items_by_key)
+    @property
+    def modalities(self):
+        return self._items_by_modality.keys()
 
     @staticmethod
     def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
@@ -452,58 +446,38 @@ def as_kwargs(
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        if self._items_by_key != other._items_by_key:
+        if self._items_by_modality != other._items_by_modality:
             return False
 
         ks = self.keys()
         return (ks == other.keys()
                 and all(nested_tensors_equal(self[k], other[k]) for k in ks))
 
-    def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
-        return self._items_by_key[key][item_index]
+    def get_num_items(self, modality: str) -> int:
+        """Get the number of items belonging to a modality."""
+        if not self._items_by_modality:
+            raise RuntimeError(
+                "`get_num_items` is not supported when "
+                "MultiModalKwargs is not initialized with `items`")
 
-    def get_items_by_modality(
-        self,
-        modality: str,
-        item_index: int,
-    ) -> Mapping[str, MultiModalFieldItem]:
+        return len(self._items_by_modality[modality])
+
+    def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
         """
         Get the keyword arguments corresponding to an item identified by
         its modality and index.
         """
-        if modality not in self._keys_by_modality:
-            available_modalities = set(self._keys_by_modality.keys())
+        if not self._items_by_modality:
+            raise RuntimeError(
+                "`get_item` is not supported when "
+                "MultiModalKwargs is not initialized with `items`")
+
+        if modality not in self._items_by_modality:
+            available_modalities = set(self._items_by_modality.keys())
             raise KeyError(f"Modality {modality!r} not found. "
                            f"Available modalities: {available_modalities}")
 
-        keys_to_gather = self._keys_by_modality[modality]
-
-        return {
-            key: self.get_item(key, item_index)
-            for key in keys_to_gather if key in self
-        }
-
-    @staticmethod
-    def from_items_by_modality(
-        items_by_modality: Mapping[str, list[Mapping[str,
-                                                     MultiModalFieldItem]]],
-        *,
-        enable_sanity_checks: bool = False,
-    ) -> "MultiModalKwargs":
-        """
-        Construct a new :class:`MultiModalKwargs` from multiple items returned
-        by :meth:`get_fields_by_modality`.
-        """
-        items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
-        for fields in items_by_modality.values():
-            for field in fields:
-                for k, v in field.items():
-                    items_by_key[k].append(v)
-
-        return MultiModalKwargs.from_items_by_key(
-            items_by_key,
-            enable_sanity_checks=enable_sanity_checks,
-        )
+        return self._items_by_modality[modality][item_index]
 
 
 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py
index 76475ddda81f4..64cdacfb4c574 100644
--- a/vllm/multimodal/processing.py
+++ b/vllm/multimodal/processing.py
@@ -20,8 +20,8 @@
 from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
 
 from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                     MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
-                     PlaceholderRange)
+                     MultiModalInputsV2, MultiModalKwargs,
+                     MultiModalKwargsItem, PlaceholderRange)
 from .parse import MultiModalDataItems, MultiModalDataParser
 
 logger = init_logger(__name__)
@@ -496,8 +496,7 @@ def __init__(self, capacity: int) -> None:
         # DEBUG: Set to None to disable
         self.debug_cache_hit_ratio_steps: Optional[int] = None
 
-        self._cache = LRUCache[str, Mapping[str,
-                                            MultiModalFieldItem]](capacity)
+        self._cache = LRUCache[str, MultiModalKwargsItem](capacity)
 
     def _maybe_log_cache_stats(self) -> None:
         steps = self.debug_cache_hit_ratio_steps
@@ -565,7 +564,7 @@ def get(
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-    ) -> Optional[Mapping[str, MultiModalFieldItem]]:
+    ) -> Optional[MultiModalKwargsItem]:
         """
         Get a processed multi-modal item from the cache
         according to its dependencies, including:
@@ -588,7 +587,7 @@ def put(
         modality: str,
         input_item: object,
         input_kwargs: Mapping[str, object],
-        output_kwargs: Mapping[str, MultiModalFieldItem],
+        output_kwargs: MultiModalKwargsItem,
     ) -> None:
         """
         Put a processed multi-modal item into the cache
@@ -784,7 +783,6 @@ def _apply_hf_processor(
         mm_kwargs = MultiModalKwargs.from_hf_inputs(
             processed_data,
             self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
-            enable_sanity_checks=self.enable_sanity_checks,
         )
 
         return prompt_ids, mm_kwargs
@@ -846,7 +844,7 @@ def _cached_apply_hf_processor(
                 hf_processor_mm_kwargs=hf_processor_mm_kwargs,
             )
 
-        mm_maybe_cached_field_items = {
+        mm_maybe_cached_kw_items = {
             modality: [
                 cache.get(model_id, modality, item, hf_processor_mm_kwargs)
                 for item in items
@@ -855,8 +853,9 @@ def _cached_apply_hf_processor(
         }
 
         mm_missing_idxs = {
-            modality: [idx for idx, out in enumerate(fields) if out is None]
-            for modality, fields in mm_maybe_cached_field_items.items()
+            modality:
+            [idx for idx, item in enumerate(kw_items) if item is None]
+            for modality, kw_items in mm_maybe_cached_kw_items.items()
         }
         mm_missing_data = {
             modality: [mm_data_items[modality][idx] for idx in idxs]
@@ -875,14 +874,11 @@ def _cached_apply_hf_processor(
             for modality in mm_missing_data_items
         }
 
-        mm_merged_field_items = dict[str, list[Mapping[str,
-                                                       MultiModalFieldItem]]]()
-        for modality, modal_items_lst in mm_maybe_cached_field_items.items():
-            merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]()
-
-            for idx, modal_items in enumerate(modal_items_lst):
-                if modal_items is None:
-                    modal_items = mm_missing_kwargs.get_items_by_modality(
+        merged_kw_items = list[MultiModalKwargsItem]()
+        for modality, kw_items in mm_maybe_cached_kw_items.items():
+            for idx, kw_item in enumerate(kw_items):
+                if kw_item is None:
+                    kw_item = mm_missing_kwargs.get_item(
                         modality,
                         mm_missing_next_idx[modality],
                     )
@@ -892,14 +888,12 @@ def _cached_apply_hf_processor(
                         modality,
                         mm_data_items[modality][idx],
                         hf_processor_mm_kwargs,
-                        modal_items,
+                        kw_item,
                     )
 
                     mm_missing_next_idx[modality] += 1
 
-                merged_modal_items_lst.append(modal_items)
-
-            mm_merged_field_items[modality] = merged_modal_items_lst
+                merged_kw_items.append(kw_item)
 
         if self.enable_sanity_checks:
             mm_missing_counts = mm_missing_data_items.get_all_counts()
@@ -909,10 +903,7 @@ def _cached_apply_hf_processor(
                     mm_missing_next_idx=mm_missing_next_idx,
                     mm_missing_counts=mm_missing_counts)
 
-        mm_kwargs = MultiModalKwargs.from_items_by_modality(
-            mm_merged_field_items,
-            enable_sanity_checks=self.enable_sanity_checks,
-        )
+        mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
 
         if self.enable_sanity_checks:
             mm_item_counts = mm_data_items.get_all_counts()
@@ -920,7 +911,7 @@ def _cached_apply_hf_processor(
             for modality, item_count in mm_item_counts.items():
                 for item_idx in range(item_count):
                     try:
-                        mm_kwargs.get_items_by_modality(modality, item_idx)
+                        mm_kwargs.get_item(modality, item_idx)
                     except Exception as e:
                         # Make it easy to set a breakpoint in the debugger
                         raise e
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 5b5a5a61cea7d..14643b0ede5b1 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -113,15 +113,24 @@ def process_inputs(
 
         # For merged preprocessor, mm_data is already mm_inputs
         precomputed_mm_inputs = None
-        if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
-            precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
+        decoder_mm_data = decoder_inputs.multi_modal_data
+        if isinstance(decoder_mm_data, MultiModalKwargs):
+            precomputed_mm_inputs = [
+                MultiModalKwargs.from_items(
+                    [decoder_mm_data.get_item(modality, item_idx)])
+                for modality in decoder_mm_data.modalities
+                for item_idx in range(decoder_mm_data.get_num_items(modality))
+            ]
 
         # Apply MM mapper
         mm_inputs = None
-        if len(decoder_inputs.multi_modal_data) > 0:
+        if len(decoder_mm_data) > 0:
             mm_inputs = self.mm_input_mapper_client.process_inputs(
-                decoder_inputs.multi_modal_data, mm_hashes,
-                decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
+                decoder_mm_data,
+                mm_hashes,
+                decoder_inputs.mm_processor_kwargs,
+                precomputed_mm_inputs,
+            )
 
         return EngineCoreRequest(
             request_id,

From f00efc8bae43ab46655ba8ffd65d70e2238e0b49 Mon Sep 17 00:00:00 2001
From: DarkLight1337 <tlleungac@connect.ust.hk>
Date: Thu, 2 Jan 2025 06:28:04 +0000
Subject: [PATCH 2/3] Simplify the code for getting multiple items

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
---
 vllm/multimodal/inputs.py   | 42 +++++++++++++++++++++----------------
 vllm/v1/engine/processor.py |  5 ++---
 2 files changed, 26 insertions(+), 21 deletions(-)

diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index c41cf57e910b7..b0a1104546186 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -273,7 +273,7 @@ def build_elems(
         self,
         key: str,
         batch: NestedTensors,
-    ) -> list[MultiModalFieldElem]:
+    ) -> Sequence[MultiModalFieldElem]:
         field = self.field_cls(key=key, modality=self.modality)
         return field.build_elems(batch, **self.field_config)  # type: ignore
 
@@ -285,7 +285,7 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     """
 
     @staticmethod
-    def from_elems(elems: list[MultiModalFieldElem]) -> "MultiModalKwargsItem":
+    def from_elems(elems: Sequence[MultiModalFieldElem]):
         return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
 
     @property
@@ -304,7 +304,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
 
     The metadata :code:`items` enables us to obtain the keyword arguments
     corresponding to each data item in :class:`MultiModalDataItems`, via
-    :meth:`get_num_items` and :meth:`get_item`.
+    :meth:`get_item` and :meth:`get_items`.
     """
 
     @staticmethod
@@ -314,7 +314,7 @@ def from_hf_inputs(
     ):
         # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
         # We assume that those fields are not used in vLLM
-        elems_by_key = dict[str, list[MultiModalFieldElem]]()
+        elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
         keys_by_modality = defaultdict[str, set[str]](set)
         for key, config in config_by_key.items():
             batch = hf_inputs.get(key)
@@ -342,7 +342,7 @@ def from_hf_inputs(
         return MultiModalKwargs.from_items(items)
 
     @staticmethod
-    def from_items(items: list[MultiModalKwargsItem]) -> "MultiModalKwargs":
+    def from_items(items: Sequence[MultiModalKwargsItem]):
         """Construct a new :class:`MultiModalKwargs` from multiple items."""
         elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
         for item in items:
@@ -453,13 +453,20 @@ def __eq__(self, other: object) -> bool:
         return (ks == other.keys()
                 and all(nested_tensors_equal(self[k], other[k]) for k in ks))
 
-    def get_num_items(self, modality: str) -> int:
-        """Get the number of items belonging to a modality."""
+    def _validate_modality(self, method_name: str, modality: str) -> None:
         if not self._items_by_modality:
             raise RuntimeError(
-                "`get_num_items` is not supported when "
+                f"`{method_name}` is not supported when "
                 "MultiModalKwargs is not initialized with `items`")
 
+        if modality not in self._items_by_modality:
+            available_modalities = set(self._items_by_modality.keys())
+            raise KeyError(f"Modality {modality!r} not found. "
+                           f"Available modalities: {available_modalities}")
+
+    def get_item_count(self, modality: str) -> int:
+        """Get the number of items belonging to a modality."""
+        self._validate_modality("get_item_count", modality)
         return len(self._items_by_modality[modality])
 
     def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
@@ -467,18 +474,17 @@ def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
         Get the keyword arguments corresponding to an item identified by
         its modality and index.
         """
-        if not self._items_by_modality:
-            raise RuntimeError(
-                "`get_item` is not supported when "
-                "MultiModalKwargs is not initialized with `items`")
-
-        if modality not in self._items_by_modality:
-            available_modalities = set(self._items_by_modality.keys())
-            raise KeyError(f"Modality {modality!r} not found. "
-                           f"Available modalities: {available_modalities}")
-
+        self._validate_modality("get_item", modality)
         return self._items_by_modality[modality][item_index]
 
+    def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
+        """
+        Get the keyword arguments corresponding to each item belonging to
+        a modality.
+        """
+        self._validate_modality("get_items", modality)
+        return self._items_by_modality[modality]
+
 
 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
 """
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 14643b0ede5b1..67810be534bed 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -116,10 +116,9 @@ def process_inputs(
         decoder_mm_data = decoder_inputs.multi_modal_data
         if isinstance(decoder_mm_data, MultiModalKwargs):
             precomputed_mm_inputs = [
-                MultiModalKwargs.from_items(
-                    [decoder_mm_data.get_item(modality, item_idx)])
+                MultiModalKwargs.from_items([item])
                 for modality in decoder_mm_data.modalities
-                for item_idx in range(decoder_mm_data.get_num_items(modality))
+                for item in decoder_mm_data.get_items(modality)
             ]
 
         # Apply MM mapper

From 22ad89f445a20fde1d8593ec9342ea2ab2457439 Mon Sep 17 00:00:00 2001
From: DarkLight1337 <tlleungac@connect.ust.hk>
Date: Thu, 2 Jan 2025 06:39:06 +0000
Subject: [PATCH 3/3] Add comment

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
---
 vllm/v1/engine/processor.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 67810be534bed..905d3d1fc3e1c 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -115,6 +115,10 @@ def process_inputs(
         precomputed_mm_inputs = None
         decoder_mm_data = decoder_inputs.multi_modal_data
         if isinstance(decoder_mm_data, MultiModalKwargs):
+            # The output of merged multi-modal processor (`decoder_mm_data`)
+            # contains the kwargs for all items from all modalities.
+            # This code separates them so that there is one set of kwargs
+            # per item per modality.
             precomputed_mm_inputs = [
                 MultiModalKwargs.from_items([item])
                 for modality in decoder_mm_data.modalities