3 changes: 3 additions & 0 deletions lightllm/common/basemodel/basemodel.py
@@ -265,6 +265,9 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len)

infer_state.multimodal_params = model_input.multimodal_params
infer_state.image_start_locs = model_input.image_start_locs
infer_state.image_token_lens = model_input.image_token_lens
infer_state.image_start_token_ids = model_input.image_start_token_ids

infer_state.mem_manager = self.mem_manager
infer_state.req_manager = self.req_manager
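(For orientation: the three tensors copied above describe, per image in the batch, where its image tokens start in the flattened prefill sequence, how many image tokens it occupies, and which token id opens the image span. Below is a minimal sketch of how such metadata could be consumed when splicing precomputed image features into the text embeddings; the helper and all names in it are illustrative, not code from this PR.)

import torch

def merge_image_embeds(input_embs, image_embeds, image_start_locs, image_token_lens):
    # input_embs: [total_tokens, hidden] text embeddings for the prefill batch
    # image_embeds: list of [len_i, hidden] tensors, one per image, in batch order
    for i, img_emb in enumerate(image_embeds):
        start = int(image_start_locs[i])
        length = int(image_token_lens[i])
        # overwrite the placeholder image-token positions with the visual features
        input_embs[start:start + length] = img_emb[:length].to(input_embs.dtype)
    return input_embs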
3 changes: 3 additions & 0 deletions lightllm/common/basemodel/batch_objs.py
@@ -17,6 +17,9 @@ class ModelInput:
mem_indexes: torch.Tensor = None
is_prefill: bool = False
b_ready_cache_len: torch.Tensor = None
image_start_locs: torch.Tensor = None
image_token_lens: torch.Tensor = None
image_start_token_ids: torch.Tensor = None
multimodal_params: list = field(default_factory=list)

# cpu variables
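(A quick, self-contained illustration of how the three new ModelInput fields relate to one another; the positions, lengths, and token id are made up and model-specific.)

import torch

# two images in one prefill request (illustrative values only)
image_start_locs = torch.tensor([5, 700])              # first token index of each image span
image_token_lens = torch.tensor([576, 256])            # number of image tokens per image
image_start_token_ids = torch.tensor([92544, 92544])   # token id that opens an image span

for loc, length, tok in zip(image_start_locs, image_token_lens, image_start_token_ids):
    print(f"image tokens occupy positions [{int(loc)}, {int(loc) + int(length)}) starting with id {int(tok)}")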
5 changes: 4 additions & 1 deletion lightllm/models/gemma3/gemma3_visual.py
@@ -16,7 +16,10 @@


class Gemma3VisionModel:
def __init__(self):
def __init__(self, kvargs):
self.weight_dir = kvargs["weight_dir"]
self.load_model(self.weight_dir)
self.cuda()
pass

def load_model(self, weight_dir):
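(With the new signature, constructing the vision model loads its weights and moves it to the GPU in one step, which is what the in-process multimodal path needs. A minimal usage sketch; the checkpoint path is illustrative, and the same kvargs pattern applies to LlavaVisionModel further down.)

from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel

kvargs = {"weight_dir": "/path/to/gemma3-checkpoint"}  # illustrative path; only "weight_dir" is read in __init__
vision_model = Gemma3VisionModel(kvargs)               # loads weights, then .cuda()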
@@ -3,6 +3,8 @@
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight

from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
from lightllm.models.vit.model import VisionTransformer
from lightllm.utils.envs_utils import get_env_start_args


# add key: language_model.xxx -> xxx
@@ -15,9 +17,36 @@ def rename_weight_keys(weights):
weights[k[len(prefix) :]] = weights[k]


def build_visual_model(args, data_type: torch.dtype):
if args.disable_extra_process_for_multimodal:
kvargs = {
"weight_dir": args.model_dir,
"data_type": data_type,
"quant_type": args.vit_quant_type,
"quant_cfg": args.vit_quant_cfg,
"max_batch_size": args.visual_infer_batch_size,
}
return VisionTransformer(kvargs=kvargs)
return None


class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
self.visual_model = build_visual_model(get_env_start_args(), data_type)
return

def load_hf_weights(self, weights):
rename_weight_keys(weights)
super().load_hf_weights(weights)


class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
self.visual_model = build_visual_model(get_env_start_args(), data_type)
return

def load_hf_weights(self, weights):
@@ -29,6 +58,8 @@ def load_hf_weights(self, weights):
class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
self.visual_model = build_visual_model(get_env_start_args(), data_type)
return

def load_hf_weights(self, weights):
@@ -40,6 +71,8 @@ def load_hf_weights(self, weights):
class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
# if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
self.visual_model = build_visual_model(get_env_start_args(), data_type)
return

def load_hf_weights(self, weights):
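(Net effect of build_visual_model: each InternVL pre/post layer-weight object now carries either an in-process VisionTransformer, when the start argument disable_extra_process_for_multimodal is set, or None, when image encoding stays in the separate visual process. A hedged dispatch sketch, not code from this PR; encode() is an assumed API and the fallback callable is supplied by the caller.)

def encode_images(layer_weight, images, send_to_visual_server):
    # Hypothetical dispatch, for illustration only.
    if layer_weight.visual_model is not None:
        # extra visual process disabled: run the ViT inside the main inference process
        return layer_weight.visual_model.encode(images)  # assumed encode() API
    # otherwise the dedicated visual server process produces the image embeddings
    return send_to_visual_server(images)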
@@ -23,7 +23,6 @@ def load_hf_weights(self, weights):
self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])

return

def verify_load(self):
5 changes: 4 additions & 1 deletion lightllm/models/llava/llava_visual.py
@@ -15,7 +15,10 @@


class LlavaVisionModel:
def __init__(self):
def __init__(self, kvargs):
self.weight_dir = kvargs["weight_dir"]
self.load_model(self.weight_dir)
self.cuda()
pass

def load_model(self, weight_dir):
@@ -0,0 +1,27 @@
import torch
import numpy as np
from lightllm.utils.envs_utils import get_env_start_args
from transformers.configuration_utils import PretrainedConfig
from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5VLTransformer


def build_visual_model(args, data_type: torch.dtype):
if args.disable_extra_process_for_multimodal:
kvargs = {
"weight_dir": args.model_dir,
"data_type": args.data_type,
"quant_type": args.vit_quant_type,
"quant_cfg": args.vit_quant_cfg,
"max_batch_size": args.visual_infer_batch_size,
}
model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"])
return Qwen2_5VLTransformer(kvargs=kvargs, **model_cfg["vision_config"]).eval().to(dtype=data_type)
return None


class Qwen2_5VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
def __init__(self, data_type, network_config, mode):
super().__init__(data_type, network_config, mode)
self.visual_model = build_visual_model(get_env_start_args(), data_type)
return
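(For reference, this new module mirrors the InternVL helper above: the ViT is built in-process only when disable_extra_process_for_multimodal is set, with the vision tower configuration read from the checkpoint via PretrainedConfig.get_config_dict. A hedged call sketch; the dtype is illustrative and args comes from the running server's start arguments.)

import torch
from lightllm.utils.envs_utils import get_env_start_args

args = get_env_start_args()                               # start args of the running server
visual_model = build_visual_model(args, torch.bfloat16)   # the function defined above
if visual_model is None:
    print("image encoding is handled by the separate visual server process")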