11# SPDX-License-Identifier: Apache-2.0
22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33from collections .abc import Iterable , Mapping , Sequence
4- from typing import Annotated , Optional , Union
4+ from typing import Annotated , Literal , Optional , Union
55
66import torch
77import torch .nn as nn
3838# yapf: enable
3939from .interfaces import MultiModalEmbeddings , SupportsMultiModal , SupportsQuant
4040from .llama import LlamaDecoderLayer , LlamaMLP , LlamaModel
41- from .utils import (AutoWeightsLoader , WeightsMapper , flatten_bn ,
42- is_pp_missing_parameter , maybe_prefix )
41+ from .utils import (AutoWeightsLoader , WeightsMapper , is_pp_missing_parameter ,
42+ maybe_prefix )
4343
4444
4545class AriaImagePixelInputs (TensorSchema ):
@@ -52,6 +52,8 @@ class AriaImagePixelInputs(TensorSchema):
5252 - w: Width of each image
5353 """
5454
55+ type : Literal ["pixel_values" ]
56+
5557 pixel_values : Annotated [
5658 torch .Tensor ,
5759 TensorShape ("bn" , 3 , "h" , "w" ),
@@ -485,6 +487,8 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
485487 This model combines a vision tower, a multi-modal projector, and a language
486488 model to perform tasks that involve both image and text inputs.
487489 """
490+ merge_by_field_config = True
491+
488492 hf_to_vllm_mapper = WeightsMapper (
489493 orig_to_new_prefix = {
490494 # mapping for new names in checkpoint saved after transformers v4.52
@@ -551,12 +555,15 @@ def _parse_and_validate_image_input(
551555 return None
552556
553557 return AriaImagePixelInputs (
554- pixel_values = flatten_bn (pixel_values , concat = True ),
555- pixel_mask = flatten_bn (pixel_mask , concat = True ),
558+ type = "pixel_values" ,
559+ pixel_values = pixel_values ,
560+ pixel_mask = pixel_mask ,
556561 )
557562
558563 def _create_patch_attention_mask (
559- self , pixel_mask : Optional [torch .Tensor ]) -> torch .Tensor :
564+ self ,
565+ pixel_mask : Optional [torch .Tensor ],
566+ ) -> Optional [torch .Tensor ]:
560567 if pixel_mask is None :
561568 return None
562569
0 commit comments