Commit 42efe60

Isotr0py and ywang96 authored

[MM][Bugfix] Replace PatchEmbed's conv3d to linear layer (#27418)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Roger Wang <hey@rogerw.io>

1 parent 88d3141 · commit 42efe60

File tree: 6 files changed, +97 −42 lines

vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/qwen2_5_vl.py
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen3_omni_moe_thinker.py
vllm/model_executor/models/qwen3_vl.py
vllm/model_executor/models/vision.py

vllm/model_executor/models/glm4_1v.py

Lines changed: 13 additions & 8 deletions

@@ -60,6 +60,7 @@
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +99,11 @@
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)

@@ -478,18 +483,15 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x


@@ -887,6 +889,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
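For context on why forward() collapses to a single call (this note is not part of the diff): the multimodal processors already hand the PatchEmbed module a 2-D tensor of flattened patches, one row per patch with features laid out channel-major over (temporal, height, width). Because the old Conv3d used kernel_size == stride and saw exactly one kernel-sized volume per row, the projection is a plain matrix multiply, so the view()/reshape round-trip can be dropped once the weight is stored in Linear form. A minimal sketch of the new path, with a plain nn.Linear standing in for vLLM's ReplicatedLinear(..., return_bias=False) and purely illustrative sizes:

import math

import torch
import torch.nn as nn

# Illustrative sizes only; the real values come from the HF vision config.
in_channels, hidden_size = 3, 1536
temporal_patch_size, patch_size = 2, 14
kernel_size = (temporal_patch_size, patch_size, patch_size)

# Plain nn.Linear stands in for ReplicatedLinear(..., return_bias=False).
proj = nn.Linear(in_channels * math.prod(kernel_size), hidden_size, bias=True)

# Input already arrives flattened: (num_patches, in_channels * T * P * P).
x = torch.randn(32, in_channels * math.prod(kernel_size))
out = proj(x)  # (num_patches, hidden_size); no view()/reshape round-trip needed
print(out.shape)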

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 14 additions & 8 deletions

@@ -26,6 +26,7 @@
 # limitations under the License.
 """Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import lru_cache, partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -56,6 +57,7 @@
     ColumnParallelLinear,
     MergedColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -98,7 +100,11 @@
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)

@@ -532,18 +538,15 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x


@@ -950,6 +953,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/qwen2_vl.py

Lines changed: 18 additions & 9 deletions

@@ -25,6 +25,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Annotated, Any, Literal, TypeAlias
@@ -53,7 +54,11 @@
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding.common import (
     dispatch_rotary_emb_function,
@@ -100,7 +105,11 @@
     init_vllm_registered_model,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)

@@ -561,18 +570,15 @@ def __init__(
         self.embed_dim = embed_dim

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             embed_dim,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=False,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.embed_dim)
+        x = self.proj(x)
         return x


@@ -835,6 +841,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/qwen3_omni_moe_thinker.py

Lines changed: 18 additions & 8 deletions

@@ -22,6 +22,7 @@
 # limitations under the License.
 """Inference-only Qwen3-Omni-Moe model (thinker part)."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from typing import Any
@@ -53,7 +54,11 @@
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -98,7 +103,11 @@
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_llm_pos_ids_for_vision, get_vit_attn_backend
+from .vision import (
+    conv3d_to_linear_weight,
+    get_llm_pos_ids_for_vision,
+    get_vit_attn_backend,
+)

 try:
     import flash_attn
@@ -131,18 +140,16 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x


@@ -559,6 +566,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/qwen3_vl.py

Lines changed: 18 additions & 9 deletions

@@ -24,6 +24,7 @@
 # limitations under the License.
 """Inference-only Qwen3VL model compatible with HuggingFace weights."""

+import math
 from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
 from itertools import islice
@@ -56,7 +57,11 @@
 from vllm.distributed import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
-from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -107,7 +112,11 @@
     _merge_multimodal_embeddings,
     maybe_prefix,
 )
-from .vision import get_vit_attn_backend, run_dp_sharded_mrope_vision_model
+from .vision import (
+    conv3d_to_linear_weight,
+    get_vit_attn_backend,
+    run_dp_sharded_mrope_vision_model,
+)

 logger = init_logger(__name__)

@@ -129,18 +138,15 @@ def __init__(
         self.hidden_size = hidden_size

         kernel_size = (temporal_patch_size, patch_size, patch_size)
-        self.proj = nn.Conv3d(
-            in_channels,
+        self.proj = ReplicatedLinear(
+            in_channels * math.prod(kernel_size),
             hidden_size,
-            kernel_size=kernel_size,
-            stride=kernel_size,
             bias=True,
+            return_bias=False,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        L, C = x.shape
-        x = x.view(L, -1, self.temporal_patch_size, self.patch_size, self.patch_size)
-        x = self.proj(x).view(L, self.hidden_size)
+        x = self.proj(x)
         return x


@@ -576,6 +582,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loaded_params: set[str] = set()

         for name, loaded_weight in weights:
+            if name.endswith("patch_embed.proj.weight"):
+                loaded_weight = conv3d_to_linear_weight(loaded_weight)
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

vllm/model_executor/models/vision.py

Lines changed: 16 additions & 0 deletions

@@ -544,3 +544,19 @@ def get_llm_pos_ids_for_vision(
         llm_pos_ids_list.append(_llm_pos_ids + start_idx)
     llm_pos_ids = torch.cat(llm_pos_ids_list, dim=1)
     return llm_pos_ids
+
+
+# Due to a performance regression with Conv3D in PyTorch 2.9, we reshape
+# Conv3D weights to Linear weights for better performance.
+# See: https://github.com/vllm-project/vllm/issues/27406
+# and https://github.com/pytorch/pytorch/issues/166122
+# FIXME(Isotr0py): Revert the PR that introduced this workaround
+# (https://github.com/vllm-project/vllm/pull/27418),
+# once the performance issue is resolved in PyTorch.
+def conv3d_to_linear_weight(conv3d_weight: torch.Tensor) -> torch.Tensor:
+    """
+    Reshape a Conv3D weight into a Linear weight. Only valid when kernel_size == stride.
+    """
+    out_channels, in_channels, kt, kh, kw = conv3d_weight.shape
+    linear_weight = conv3d_weight.reshape(out_channels, in_channels * kt * kh * kw)
+    return linear_weight
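The docstring's caveat (kernel_size must equal stride) is what makes the reshape lossless: each flattened patch row covers exactly one non-overlapping, kernel-sized volume, so the convolution reduces to a matrix multiply with the flattened kernel. A minimal standalone sketch (plain PyTorch, illustrative sizes, not part of the commit) that checks the reshaped weight reproduces the Conv3d output:

import torch
import torch.nn as nn

torch.manual_seed(0)

in_channels, hidden_size = 3, 8
temporal_patch_size, patch_size = 2, 4
kernel_size = (temporal_patch_size, patch_size, patch_size)

conv = nn.Conv3d(
    in_channels, hidden_size, kernel_size=kernel_size, stride=kernel_size, bias=False
)

# Flattened patches, one row per patch, features ordered (channels, t, h, w)
# to match the view() the old forward() performed.
num_patches = 5
x = torch.randn(num_patches, in_channels * temporal_patch_size * patch_size * patch_size)

# Old path: view each row back into a (C, T, H, W) volume and run Conv3d.
x_vol = x.view(num_patches, in_channels, temporal_patch_size, patch_size, patch_size)
out_conv = conv(x_vol).view(num_patches, hidden_size)

# New path: reshape the Conv3d weight the same way conv3d_to_linear_weight() does,
# then apply it as a plain matrix multiply.
out_c, in_c, kt, kh, kw = conv.weight.shape
linear_weight = conv.weight.reshape(out_c, in_c * kt * kh * kw)
out_linear = x @ linear_weight.t()

assert torch.allclose(out_conv, out_linear, atol=1e-5)

When a bias is present, it carries over unchanged: both Conv3d and Linear add one bias value per output channel after the weighted sum, which is why load_weights only converts the .weight tensor.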
