Commit eba4b00

0807-fix-qwen2vl
1 parent 96a0afb commit eba4b00

12 files changed: +456, -404 lines
lightllm/models/qwen2_5_vl/layer_weights/pre_and_post_layer_weight.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import torch
+import numpy as np
+from lightllm.utils.envs_utils import get_env_start_args
+from transformers.configuration_utils import PretrainedConfig
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5VLTransformer
+
+
+def build_visual_model(args, data_type: torch.dtype):
+    if args.disable_extra_process_for_multimodal:
+        kvargs = {
+            "weight_dir": args.model_dir,
+            "data_type": args.data_type,
+            "quant_type": args.vit_quant_type,
+            "quant_cfg": args.vit_quant_cfg,
+            "max_batch_size": args.visual_infer_batch_size,
+        }
+        model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"])
+        return Qwen2_5VLTransformer(kvargs=kvargs, **model_cfg["vision_config"]).eval().to(dtype=data_type)
+    return None
+
+
+class Qwen2_5VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        self.visual_model = build_visual_model(get_env_start_args(), data_type)
+        return

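A minimal usage sketch for the new helper (not part of the commit): the args object below is a hypothetical stand-in for lightllm's parsed start args, and the model path is a placeholder.

# Sketch only: when --disable_extra_process_for_multimodal is set, the ViT is built
# inside the main process; otherwise build_visual_model returns None and the separate
# visual process keeps handling image encoding.
import torch
from types import SimpleNamespace

args = SimpleNamespace(
    disable_extra_process_for_multimodal=True,
    model_dir="/path/to/Qwen2.5-VL",          # placeholder weight directory
    data_type="bf16",
    vit_quant_type=None,
    vit_quant_cfg=None,
    visual_infer_batch_size=1,
)
visual_model = build_visual_model(args, data_type=torch.bfloat16)
# -> Qwen2_5VLTransformer in eval mode, or None when the flag is off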
lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 29 additions & 8 deletions
@@ -16,7 +16,7 @@
 from torch.nn import LayerNorm
 from transformers.activations import ACT2FN
 import math
-from lightllm.models.qwen2_vl.vision_process import get_image, Qwen2VLImageProcessor
+from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
 from transformers import AutoProcessor
 from safetensors import safe_open
 from transformers.utils import TensorType
@@ -212,9 +212,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


-class Qwen2_5_VisionTransformerPretrainedModel(nn.Module):
+class Qwen2_5VLTransformer(nn.Module):
     def __init__(
         self,
+        weight_dir,
         depth=32,
         hidden_size=3584,
         hidden_act="silu",
@@ -278,6 +279,11 @@ def __init__(

         self.gradient_checkpointing = False

+        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
+        with open(processor_config_path, "r") as f:
+            processor_config_dict = json.load(f)
+        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+
         self.device = self.get_device()
         self.dtype = self.get_dtype()

@@ -416,12 +422,27 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.

         return hidden_states

-    def load_model(self, weight_dir):
+    def load_image(self, img: List[ImageItem]):
+        pixel_values = None
+        if isinstance(img, ImageItem):
+            image_data = read_shm(get_shm_name_data(img.uuid))
+            image_data = Image.open(BytesIO(image_data))
+            image_data = resize_image(image_data)
+            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
+            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
+            image_grid_thw = image_inputs["image_grid_thw"]
+        elif isinstance(img, dict):
+            image_data = read_shm(get_shm_name_data(img["uuid"]))
+            image_data = Image.open(BytesIO(image_data))
+            image_data = resize_image(image_data)
+            image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
+            pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
+            image_grid_thw = image_inputs["image_grid_thw"]
+        else:
+            raise Exception("Unsupport input types: {} for {}".format(type(img), img))
+        return pixel_values.to(dtype=self.get_dtype()), image_grid_thw

-        processor_config_path = os.path.join(weight_dir, "preprocessor_config.json")
-        with open(processor_config_path, "r") as f:
-            processor_config_dict = json.load(f)
-        self.processor = Qwen2VLImageProcessor(**processor_config_dict)
+    def load_model(self, weight_dir):

         bin_weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")]
         if bin_weight_files:
@@ -455,7 +476,7 @@ def encode(self, images: List[ImageItem]):
             uuids.append(img.uuid)
             image_data = read_shm(get_shm_name_data(img.uuid))
             image_data = Image.open(BytesIO(image_data))
-            image_data = get_image(image_data)
+            image_data = resize_image(image_data)
             image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt")
             pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16)
             image_grid_thw = image_inputs["image_grid_thw"]
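Rough usage of the new load_image path (not part of the commit): it assumes vit is an initialized Qwen2_5VLTransformer and img is an ImageItem whose raw bytes were already written to shared memory under img.uuid.

# Sketch only: load_image reads the image bytes from shared memory, resizes them,
# runs the Qwen2VLImageProcessor, and returns the patch tensor plus its (t, h, w) grid.
pixel_values, image_grid_thw = vit.load_image(img)
# forward(hidden_states, grid_thw) then turns the patches into vision embeddings;
# device placement is left to the caller in this sketch.
image_embeds = vit(pixel_values.to(vit.device), image_grid_thw)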
lightllm/models/qwen2_vl/layer_weights/pre_and_post_layer_weight.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import torch
+import numpy as np
+from lightllm.utils.envs_utils import get_env_start_args
+from transformers.configuration_utils import PretrainedConfig
+from lightllm.models.qwen2.layer_weights.pre_and_post_layer_weight import Qwen2PreAndPostLayerWeight
+from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VLTransformer
+
+
+def build_visual_model(args, data_type: torch.dtype):
+    if args.disable_extra_process_for_multimodal:
+        kvargs = {
+            "weight_dir": args.model_dir,
+            "data_type": args.data_type,
+            "quant_type": args.vit_quant_type,
+            "quant_cfg": args.vit_quant_cfg,
+            "max_batch_size": args.visual_infer_batch_size,
+        }
+        model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"])
+        return Qwen2VLTransformer(kvargs=kvargs, **model_cfg["vision_config"]).eval().to(dtype=data_type)
+    return None
+
+
+class Qwen2VLPreAndPostLayerWeight(Qwen2PreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        self.visual_model = build_visual_model(get_env_start_args(), data_type)
+        return

lightllm/models/qwen2_vl/model.py

Lines changed: 35 additions & 1 deletion
@@ -16,6 +16,8 @@
 from lightllm.common.build_utils import repair_config
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.models.qwen2_vl.layer_weights.pre_and_post_layer_weight import Qwen2VLPreAndPostLayerWeight
+from lightllm.models.qwen2_5_vl.layer_weights.pre_and_post_layer_weight import Qwen2_5VLPreAndPostLayerWeight
 from lightllm.models.qwen2_vl.layer_infer.transformer_layer_infer import Qwen2VLTransformerLayerInfer

 import torch
@@ -93,12 +95,44 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None, **kwargs):
         return input_ids


-@ModelRegistry(["qwen2_vl", "qwen2_5_vl"], is_multimodal=True)
+@ModelRegistry(["qwen2_vl"], is_multimodal=True)
 class Qwen2VLTpPartModel(Qwen2TpPartModel):

     pre_layer_infer_class = LlamaMultimodalPreLayerInfer
     transformer_layer_infer_class = Qwen2VLTransformerLayerInfer

+    pre_and_post_weight_class = Qwen2VLPreAndPostLayerWeight
+
+    infer_state_class = Qwen2VLInferStateInfo
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_inferstate_cls(self):
+        if get_env_start_args().enable_fa3:
+            self.infer_state_class = Qwen2VLFlashAttentionStateInfo
+
+    def _init_config(self):
+        with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
+            self.config = json.load(json_file)
+        # rename keys
+        repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+        repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+        repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+        if self.finetune_config:
+            self.config["vocab_size"] = self.finetune_config.vocab_size
+        return
+
+
+@ModelRegistry(["qwen2_5_vl"], is_multimodal=True)
+class Qwen2_5VLTpPartModel(Qwen2TpPartModel):
+
+    pre_layer_infer_class = LlamaMultimodalPreLayerInfer
+    transformer_layer_infer_class = Qwen2VLTransformerLayerInfer
+
+    pre_and_post_weight_class = Qwen2_5VLPreAndPostLayerWeight
+
     infer_state_class = Qwen2VLInferStateInfo

     def __init__(self, kvargs):

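The split registration means each model type now binds its own pre_and_post_weight_class (and therefore its own in-process ViT builder). Below is a self-contained sketch of the decorator-registry idea being relied on here; it is an illustration only, not lightllm's actual ModelRegistry implementation.

# Toy registry illustrating why the combined ["qwen2_vl", "qwen2_5_vl"] registration
# was split: each model_type must resolve to a class carrying its own weight class.
_MODEL_REGISTRY = {}

def register(model_types):
    def deco(cls):
        for name in model_types:
            _MODEL_REGISTRY[name] = cls
        return cls
    return deco

@register(["qwen2_vl"])
class ToyQwen2VL:  # stands in for Qwen2VLTpPartModel
    pre_and_post_weight_class = "Qwen2VLPreAndPostLayerWeight"

@register(["qwen2_5_vl"])
class ToyQwen2_5VL:  # stands in for Qwen2_5VLTpPartModel
    pre_and_post_weight_class = "Qwen2_5VLPreAndPostLayerWeight"

# config.json's model_type picks the model class, which picks the weight class:
assert _MODEL_REGISTRY["qwen2_5_vl"].pre_and_post_weight_class == "Qwen2_5VLPreAndPostLayerWeight"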