Commit a462331

[Bugfix] Disable moe inplace for torch >= 2.9 (#26497)
Signed-off-by: Bill Nell <bnell@redhat.com>
Parent commit: 4069db3

4 files changed (+22, -6 lines)
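
The fix is the same in every file: each site that aliased the output buffer to `hidden_states` when `inplace=True` now also consults a new `disable_inplace()` helper, which reports whether the running torch is 2.9 or newer. The underlying problem (see #26378) is that torch custom ops can't deal with outputs that alias their inputs. A minimal, hypothetical demo op (not vLLM code) illustrating that contract:

import torch

# Custom ops should return freshly allocated tensors. Returning `x`
# itself (an output aliasing an input) is essentially what the inplace
# MoE path did, and is what newer torch rejects.
@torch.library.custom_op("demo::passthrough", mutates_args=())
def passthrough(x: torch.Tensor) -> torch.Tensor:
    return x.clone()  # fresh tensor: OK. `return x` would alias the input.

@passthrough.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(x)

y = passthrough(torch.randn(4))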

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 6 additions & 2 deletions

@@ -14,7 +14,7 @@
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP,
 )
-from vllm.model_executor.layers.fused_moe.utils import _resize_cache
+from vllm.model_executor.layers.fused_moe.utils import _resize_cache, disable_inplace
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_make_workspace_new,
     marlin_moe_intermediate_size,
@@ -235,7 +235,11 @@ def fused_marlin_moe(
     ).view(-1, topk, K)
 
     if output is None:
-        output = hidden_states if inplace else torch.empty_like(hidden_states)
+        if inplace and not disable_inplace():
+            output = hidden_states
+        else:
+            output = torch.empty_like(hidden_states)
+
     return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
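
The marlin path's last step sums the per-(token, expert) intermediate over the top-k dimension directly into whichever buffer was selected above, so the inplace decision is purely about what `output` points at. A toy reduction with made-up shapes:

import torch

M, topk, K = 4, 2, 8                            # tokens, experts per token, hidden dim
hidden_states = torch.randn(M, K)
intermediate_cache3 = torch.randn(M * topk, K)  # per-(token, expert) results

# Out-of-place variant, which torch >= 2.9 now always takes:
output = torch.empty_like(hidden_states)
torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
assert output.shape == (M, K)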

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 6 additions & 2 deletions

@@ -39,6 +39,7 @@
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
     activation_without_mul,
+    disable_inplace,
     moe_kernel_quantize_input,
 )
 from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
@@ -1516,7 +1517,7 @@ def torch_vllm_outplace_fused_experts(**kwargs) -> torch.Tensor:
 
 
 def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
-    if inplace:
+    if inplace and not disable_inplace():
         return torch_vllm_inplace_fused_experts
     return torch_vllm_outplace_fused_experts
 
@@ -1766,7 +1767,10 @@ def fused_experts_impl(
     else:
         raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")
 
-    out_hidden_states = hidden_states if inplace else torch.empty_like(hidden_states)
+    if inplace and not disable_inplace():
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
 
     if ocp_mx_scheme is not None:
         # TODO: On platforms for which `current_platform.supports_mx()` is True
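
With the guard in `dispatch_fused_experts_func`, torch >= 2.9 always routes to the outplace op even when the caller asks for inplace, so `inplace=True` no longer guarantees that `hidden_states` is mutated; callers should consume the return value. A self-contained sketch of the dispatch shape (the two `torch_vllm_*` functions are stubs, and `disable_inplace` is re-derived inline so the snippet runs on its own):

from typing import Any, Callable

import torch

def torch_vllm_inplace_fused_experts(**kwargs: Any) -> torch.Tensor:
    hidden_states = kwargs["hidden_states"]
    hidden_states.add_(1.0)                 # pretend kernel: mutates the input
    return hidden_states

def torch_vllm_outplace_fused_experts(**kwargs: Any) -> torch.Tensor:
    return kwargs["hidden_states"] + 1.0    # pretend kernel: fresh tensor

def disable_inplace() -> bool:              # stand-in for the vLLM helper
    major, minor = (int(v) for v in torch.__version__.split(".")[:2])
    return (major, minor) >= (2, 9)

def dispatch_fused_experts_func(inplace: bool) -> Callable[..., torch.Tensor]:
    # On torch >= 2.9 the inplace wrapper is never chosen, even if requested.
    if inplace and not disable_inplace():
        return torch_vllm_inplace_fused_experts
    return torch_vllm_outplace_fused_experts

out = dispatch_fused_experts_func(inplace=True)(hidden_states=torch.ones(2, 3))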

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 2 additions & 1 deletion

@@ -13,6 +13,7 @@
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
     count_expert_num_tokens,
+    disable_inplace,
 )
 from vllm.utils import cdiv
 from vllm.v1.worker.ubatching import (
@@ -1139,7 +1140,7 @@ def forward(
         - torch.Tensor: The output tensor after applying the MoE layer.
         """
 
-        if inplace and self.shared_experts is None:
+        if inplace and self.shared_experts is None and not disable_inplace():
            output = hidden_states
        else:
            output = torch.zeros_like(hidden_states)

vllm/model_executor/layers/fused_moe/utils.py

Lines changed: 8 additions & 1 deletion

@@ -23,7 +23,7 @@
     mxfp8_e4m3_quantize,
 )
 from vllm.triton_utils import tl, triton
-from vllm.utils import cdiv
+from vllm.utils import cdiv, is_torch_equal_or_newer
 from vllm.utils.flashinfer import flashinfer_fp4_quantize
@@ -321,3 +321,10 @@ def _validate_scale_shape(
 
 def activation_without_mul(activation: str) -> str:
     return activation + "_no_mul"
+
+
+# Torch custom ops can't deal with outputs aliasing inputs so we need to
+# disable inplace for torch >= 2.9.
+# See https://github.com/vllm-project/vllm/issues/26378
+def disable_inplace() -> bool:
+    return is_torch_equal_or_newer("2.9")
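
`is_torch_equal_or_newer` already lives in `vllm.utils`; for context, a behaviorally similar standalone predicate (a sketch, assuming PEP 440 version strings such as "2.9.0+cu121") can be built on `packaging`:

import torch
from packaging import version

def is_torch_equal_or_newer_sketch(target: str) -> bool:
    # Compare release tuples so local build tags like "+cu121" don't
    # affect the comparison.
    return version.parse(torch.__version__).release >= version.parse(target).release

def disable_inplace() -> bool:
    # Mirrors the new helper: inplace MoE is off on torch >= 2.9 because
    # custom ops may not return outputs that alias their inputs.
    return is_torch_equal_or_newer_sketch("2.9")

print(torch.__version__, "-> disable_inplace():", disable_inplace())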
