Skip to content

Commit 98f30b8

Browse files
authored
[Model] Fix Skywork R1V mlp (#26673)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent 3cd3666 commit 98f30b8

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

vllm/model_executor/models/skyworkr1v.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
691691
prefix=maybe_prefix(prefix, "language_model"),
692692
)
693693

694-
self.mlp1 = self._init_mlp1(config)
694+
self.mlp1 = self._init_mlp1(
695+
config, quant_config, prefix=maybe_prefix(prefix, "mlp1")
696+
)
695697

696698
self.img_context_token_id = None
697699
self.visual_token_mask = None
@@ -738,7 +740,12 @@ def _init_vision_model(
738740
else:
739741
return InternVisionPatchModel(config.vision_config)
740742

741-
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
743+
def _init_mlp1(
744+
self,
745+
config: PretrainedConfig,
746+
quant_config: QuantizationConfig,
747+
prefix: str = "",
748+
) -> nn.Module:
742749
vit_hidden_size = config.vision_config.hidden_size
743750
llm_hidden_size = config.text_config.hidden_size
744751

@@ -748,9 +755,17 @@ def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
748755
vit_hidden_size * int(1 / self.downsample_ratio) ** 2,
749756
llm_hidden_size,
750757
return_bias=False,
758+
quant_config=quant_config,
759+
prefix=f"{prefix}.1",
751760
),
752761
nn.GELU(),
753-
ReplicatedLinear(llm_hidden_size, llm_hidden_size, return_bias=False),
762+
ReplicatedLinear(
763+
llm_hidden_size,
764+
llm_hidden_size,
765+
return_bias=False,
766+
quant_config=quant_config,
767+
prefix=f"{prefix}.3",
768+
),
754769
)
755770

756771
def pixel_shuffle(self, x, scale_factor=0.5):

0 commit comments

Comments
 (0)