Skip to content

Commit

Permalink
[bugfix] fix chatglm dummy_data_for_glmv (#9955)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <youkaichao@gmail.com>
  • Loading branch information
youkaichao authored Nov 2, 2024
1 parent d6459b4 commit 74b529c
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions vllm/model_executor/models/chatglm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
token_inputs)
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
Expand All @@ -31,8 +31,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs)
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalInputs
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
Expand Down Expand Up @@ -117,16 +116,15 @@ def get_max_glmv_image_tokens(ctx: InputContext):
raise NotImplementedError(msg)


def dummy_data_for_glmv(
ctx: InputContext, seq_len: int, mm_counts: Mapping[str, int]
) -> Tuple[SequenceData, Optional[MultiModalDataDict]]:
def dummy_data_for_glmv(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]) -> DummyData:
hf_config = ctx.get_hf_config(ChatGLMConfig)
vision_config = getattr(hf_config, 'vision_config', None)

if vision_config is None:
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)
seq_data = SequenceData(token_ids)
return seq_data, None
return DummyData(seq_data, None)
elif isinstance(vision_config, dict):
image_size = vision_config["image_size"]
image_placeholder_length = calculate_image_placeholder(vision_config)
Expand All @@ -141,7 +139,7 @@ def dummy_data_for_glmv(
"image": Image.new("RGB", (image_size, image_size), color=0)
}

return seq_data, mm_data
return DummyData(seq_data, mm_data)

msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
Expand Down

0 comments on commit 74b529c

Please sign in to comment.