3131
3232from fastdeploy .config import FDConfig
3333from fastdeploy .distributed .communication import tensor_model_parallel_all_reduce
34+ from fastdeploy .model_executor .forward_meta import ForwardMeta
3435from fastdeploy .model_executor .graph_optimization .decorator import (
3536 cuda_graph_buffers ,
3637 support_graph_optimization ,
3738)
3839from fastdeploy .model_executor .layers .embeddings import VocabParallelEmbedding
40+ from fastdeploy .model_executor .layers .image_op import (
41+ text_image_gather_scatter ,
42+ text_image_index_out ,
43+ )
3944from fastdeploy .model_executor .layers .linear import ReplicatedLinear
4045from fastdeploy .model_executor .layers .lm_head import ParallelLMHead
4146from fastdeploy .model_executor .layers .moe .moe import FusedMoE
4550 Ernie4_5_MLP ,
4651)
4752from fastdeploy .model_executor .models .model_base import ModelForCasualLM
48- from fastdeploy .platforms import current_platform
49-
50- if current_platform .is_cuda ():
51- from fastdeploy .model_executor .ops .gpu import (
52- text_image_gather_scatter ,
53- text_image_index_out ,
54- )
55- elif current_platform .is_xpu ():
56- from fastdeploy .model_executor .ops .xpu import (
57- text_image_gather_scatter ,
58- text_image_index_out ,
59- )
60-
61- from fastdeploy .model_executor .forward_meta import ForwardMeta
6253
6354
6455class Ernie4_5_VLMLP (Ernie4_5_MLP ):
@@ -75,7 +66,6 @@ class VLMoEMeta:
7566 text_input : paddle .Tensor
7667 text_index : paddle .Tensor
7768 image_index : paddle .Tensor
78- image_mask : paddle .Tensor
7969 token_type_ids : paddle .Tensor
8070 image_token_num : paddle .Tensor
8171
@@ -86,7 +76,6 @@ def __str__(self):
8676 f" text_input: { self .text_input } , pointer: { self .text_input .data_ptr ()} \n "
8777 f" text_index: { self .text_index } , pointer: { self .text_index .data_ptr ()} \n "
8878 f" image_index: { self .image_index } , pointer: { self .image_index .data_ptr ()} \n "
89- f" image_mask: { self .image_mask } , pointer: { self .image_mask .data_ptr ()} \n "
9079 f" token_type_ids: { self .token_type_ids } , pointer: { self .token_type_ids .data_ptr ()} \n \n "
9180 f")"
9281 )
@@ -419,11 +408,6 @@ def forward(
419408 "dtype" : "model_config.dtype" ,
420409 "value" : 1 ,
421410 },
422- "image_mask" : {
423- "shape" : ["parallel_config.max_model_len" , "model_config.hidden_size" ],
424- "dtype" : "bool" ,
425- "value" : False ,
426- },
427411 "text_index" : {
428412 "shape" : ["parallel_config.max_model_len" ],
429413 "dtype" : "int32" ,
0 commit comments