From cd5bd84585f0d0589bfd939e28217b8eda74d34d Mon Sep 17 00:00:00 2001
From: B-201
Date: Sun, 17 Nov 2024 21:14:22 +0800
Subject: [PATCH 1/2] separate chatglm

Signed-off-by: B-201
---
 vllm/model_executor/models/chatglm.py | 97 +++++++++++++++++++++------
 1 file changed, 78 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 70e9b607b0642..3a6c1cd9012d8 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -29,6 +29,7 @@
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import MultiModalData, MultiModalKwargs
@@ -573,25 +574,8 @@ def forward(
         return hidden_states
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
-@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
-class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
-                         SupportsMultiModal):
-    packed_modules_mapping = {
-        "query_key_value": ["query_key_value"],
-        "dense_h_to_4h": ["dense_h_to_4h"]
-    }
-    # LoRA specific attributes
-    supported_lora_modules = [
-        "query_key_value",
-        "dense",
-        "dense_h_to_4h",
-        "dense_4h_to_h",
-    ]
-    embedding_modules = {}
-    embedding_padding_modules = []
+class ChatGLMBaseModel(nn.Module, SupportsLoRA, SupportsPP,
+                       SupportsMultiModal):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -686,3 +670,78 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, combined_weight)
+
+
+class ChatGLM(ChatGLMBaseModel):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+
+class ChatGLMV(ChatGLMBaseModel):
+    packed_modules_mapping = {
+        "query_key_value": ["query_key_value"],
+        "dense_h_to_4h": ["dense_h_to_4h"]
+    }
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "query_key_value",
+        "dense",
+        "dense_h_to_4h",
+        "dense_4h_to_h",
+        # vision
+        "fc1",
+        "fc2",
+        "gate_proj",
+        "linear_proj"
+    ]
+
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="transformer.encoder",
+            connector="transformer.vision.linear_proj",
+            tower_model="transformer.vision.transformer")
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper(mm_input_mapper_for_glmv)
+@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_glmv_image_tokens)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_glmv)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_glmv)
+class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP,
+                         SupportsMultiModal):
+    # Ensure that the LoRA support check passes when the class is not
+    # initialized, but set all these attributes to empty.
+    packed_modules_mapping = {}
+    supported_lora_modules = []
+    embedding_modules = {}
+    embedding_padding_modules = []
+
+    def __new__(
+        cls,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+    ) -> None:
+        config = vllm_config.model_config.hf_config
+        # Initialize VL
+        if hasattr(config, "visual"):
+            return ChatGLMV(vllm_config=vllm_config, prefix=prefix)
+        # Initialize LLM
+        else:
+            return ChatGLM(vllm_config=vllm_config, prefix=prefix)
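Note on the first patch: the interesting move is the `__new__`-based dispatch. `ChatGLMForCausalLM` stays the single class that the input and multimodal registries point at, but constructing it actually yields a `ChatGLM` (text-only) or `ChatGLMV` (vision) instance, chosen by whether the HF config carries a `visual` section. A minimal, self-contained sketch of the pattern follows; every class name in it is an illustrative stand-in, not vLLM code:

```python
# Sketch of the __new__ dispatch used by ChatGLMForCausalLM above.
# Config/BaseModel/TextModel/VisionModel/Model are illustrative names.
from typing import Optional


class Config:
    def __init__(self, visual: Optional[dict] = None):
        # Mirrors the hf_config check: a "visual" field marks a VL checkpoint.
        if visual is not None:
            self.visual = visual


class BaseModel:
    def __init__(self, config: Config):
        self.config = config


class TextModel(BaseModel):    # plays the role of ChatGLM
    pass


class VisionModel(BaseModel):  # plays the role of ChatGLMV
    pass


class Model(BaseModel):        # plays the role of ChatGLMForCausalLM
    def __new__(cls, config: Config):
        # Dispatch on the config before any __init__ runs; registries can
        # keep referring to this one facade class by name.
        if hasattr(config, "visual"):
            return VisionModel(config)
        return TextModel(config)


assert isinstance(Model(Config()), TextModel)
assert isinstance(Model(Config(visual={})), VisionModel)
```

Because the two variants are siblings of the facade rather than subclasses of it, Python skips the facade's `__init__` on the returned object, so each variant is constructed exactly once.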
From 3d76524fcab29ae80c9b34f94d29955f9fe102a0 Mon Sep 17 00:00:00 2001
From: B-201
Date: Mon, 18 Nov 2024 18:45:53 +0800
Subject: [PATCH 2/2] support lora

Signed-off-by: B-201
---
 vllm/model_executor/models/chatglm.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 45f75a0aa1090..625e31bb0d368 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -674,6 +674,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, combined_weight)
+            loaded_params.add(combined_name)
+        return loaded_params
 
 
 class ChatGLM(ChatGLMBaseModel):
@@ -696,7 +698,8 @@ class ChatGLM(ChatGLMBaseModel):
 class ChatGLMV(ChatGLMBaseModel):
     packed_modules_mapping = {
         "query_key_value": ["query_key_value"],
-        "dense_h_to_4h": ["dense_h_to_4h"]
+        "dense_h_to_4h": ["dense_h_to_4h"],
+        "merged_proj": ["gate_proj", "dense_h_to_4h"]
     }
     # LoRA specific attributes
     supported_lora_modules = [
@@ -707,7 +710,7 @@ class ChatGLMV(ChatGLMBaseModel):
         # vision
         "fc1",
         "fc2",
-        "gate_proj",
+        "merged_proj",
         "linear_proj"
     ]
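On the second patch: the first hunk brings the combined-weight path in line with a `load_weights` that now returns the set of loaded parameter names, and the `packed_modules_mapping` change tells the LoRA machinery which checkpoint-level projections a fused linear stands for. In the GLM-4V vision GLU, `gate_proj` and `dense_h_to_4h` execute as one fused `merged_proj`, so adapter weights trained against the two logical projections must be stacked along the output dimension before they can target the fused module. A rough sketch of the shape bookkeeping; the dimensions and rank are made-up toy values, and real vLLM packing also tracks per-slice offsets:

```python
# Rough sketch of why "merged_proj": ["gate_proj", "dense_h_to_4h"] is
# needed. All sizes here are illustrative, not GLM-4V's real dimensions.
import torch

hidden, inter, rank = 8, 16, 4

# The checkpoint (and a LoRA adapter) see two separate projections...
gate_proj_w = torch.randn(inter, hidden)
dense_h_to_4h_w = torch.randn(inter, hidden)

# ...but the model executes them as one fused linear whose weight is the
# row-wise concatenation of the two parts.
merged_proj_w = torch.cat([gate_proj_w, dense_h_to_4h_w], dim=0)
assert merged_proj_w.shape == (2 * inter, hidden)

# LoRA B matrices live on the output dimension, so they must be stacked
# the same way for the adapter to line up with merged_proj.
lora_b_gate = torch.randn(inter, rank)
lora_b_dense = torch.randn(inter, rank)
lora_b_merged = torch.cat([lora_b_gate, lora_b_dense], dim=0)
assert lora_b_merged.shape == (2 * inter, rank)
```

Listing `merged_proj` (rather than the unfused `gate_proj`) in `supported_lora_modules` then matches the module name that actually exists in the instantiated model.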