Commit 1d5b260

fix crash introduced by upstream PR 25613 and PR 23991 (#259)
vllm-project/vllm#23991 vllm-project/vllm#25613

Signed-off-by: Chendi Xue <Chendi.Xue@intel.com>
1 parent 60808d7 commit 1d5b260
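
The crash comes from the call sites below unpacking two values from FusedMoE.select_experts after the upstream changes, which, as the diffs show, now hand back a third zero_expert_result value. A minimal sketch of the failure mode and of the fix, using a stand-in function rather than the real vLLM API:

    # Stand-in for FusedMoE.select_experts after the upstream change: it now
    # returns (topk_weights, topk_ids, zero_expert_result).
    def select_experts_stub():
        return "topk_weights", "topk_ids", "zero_expert_result"

    try:
        topk_weights, topk_ids = select_experts_stub()  # old two-value unpacking
    except ValueError as err:
        print(f"crash: {err}")  # "too many values to unpack (expected 2)"

    # Fixed three-value unpacking, as applied in the diffs below:
    topk_weights, topk_ids, zero_expert_result = select_experts_stub()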

5 files changed, +68 -42 lines changed


vllm_gaudi/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ def register():
 
 def register_ops():
     """Register custom ops for the HPU platform."""
+    HpuPlatform.patch_for_pt27()
     import vllm_gaudi.v1.sample.hpu_rejection_sampler  # noqa: F401
     import vllm_gaudi.distributed.kv_transfer.kv_connector.v1.hpu_nixl_connector  # noqa: F401
     import vllm_gaudi.ops.hpu_fused_moe  # noqa: F401

vllm_gaudi/ops/hpu_compressed_tensors.py

Lines changed: 24 additions & 22 deletions
@@ -226,16 +226,17 @@ def apply(
         input_shape = x.shape
         x = x.view(-1, x.shape[-1])
         if use_grouped_topk or custom_routing_function is not None:
-            topk_weights, topk_ids = FusedMoE.select_experts(hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias)
+            topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias)
         else:
             import torch.nn.functional as F
             topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)

@@ -663,18 +664,19 @@ def apply(
         x = x.view(-1, x.shape[-1])
 
         if use_grouped_topk or custom_routing_function is not None:
-            topk_weights, topk_ids = FusedMoE.select_experts(hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                routed_scaling_factor=routed_scaling_factor,
-                e_score_correction_bias=e_score_correction_bias,
-                indices_type=self.topk_indices_dtype)
+            topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                routed_scaling_factor=routed_scaling_factor,
+                e_score_correction_bias=e_score_correction_bias,
+                indices_type=self.topk_indices_dtype)
         else:
             import torch.nn.functional as F
             topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
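
The same three-value unpacking is applied in hpu_fp8.py and hpu_fused_moe.py below; zero_expert_result is captured but not otherwise used in the hunks shown. For an out-of-tree call site that has to run against both the old and the new upstream signature, one option (not what this commit does) is a small compatibility shim; the helper name here is hypothetical:

    # Hypothetical helper, not part of this commit: accept either the 2-tuple or
    # the 3-tuple return of select_experts so one call site spans both versions.
    def unpack_select_experts(result):
        if len(result) == 2:
            topk_weights, topk_ids = result
            zero_expert_result = None  # older upstream: no third value
        else:
            topk_weights, topk_ids, zero_expert_result = result
        return topk_weights, topk_ids, zero_expert_result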

vllm_gaudi/ops/hpu_fp8.py

Lines changed: 11 additions & 10 deletions
@@ -126,16 +126,17 @@ def apply(
         input_shape = x.shape
         x = x.view(-1, x.shape[-1])
         if use_grouped_topk or custom_routing_function is not None:
-            topk_weights, topk_ids = FusedMoE.select_experts(hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias)
+            topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias)
         else:
             import torch.nn.functional as F
             topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)

vllm_gaudi/ops/hpu_fused_moe.py

Lines changed: 11 additions & 10 deletions
@@ -53,16 +53,17 @@ def forward_oot(
         input_shape = x.shape
         x = x.view(-1, x.shape[-1])
         if use_grouped_topk or custom_routing_function is not None:
-            topk_weights, topk_ids = FusedMoE.select_experts(hidden_states=x,
-                router_logits=router_logits,
-                use_grouped_topk=use_grouped_topk,
-                top_k=top_k,
-                renormalize=renormalize,
-                topk_group=topk_group,
-                num_expert_group=num_expert_group,
-                custom_routing_function=custom_routing_function,
-                scoring_func=scoring_func,
-                e_score_correction_bias=e_score_correction_bias)
+            topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias)
         else:
             import torch.nn.functional as F
             topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)

vllm_gaudi/platform.py

Lines changed: 21 additions & 0 deletions
@@ -195,3 +195,24 @@ def _synced_weight_loader(param, *args, **kwargs):
             return out
 
         return _synced_weight_loader
+
+    @classmethod
+    def patch_for_pt27(cls) -> None:
+
+        from vllm.utils import is_torch_equal_or_newer
+        if is_torch_equal_or_newer("2.8.0"):
+            return
+
+        from vllm.model_executor import BasevLLMParameter
+        parent_class = BasevLLMParameter.__mro__[1]
+        parent_torch_function = getattr(parent_class, "__torch_function__", None)
+
+        def torch_function(origin_cls, func, types, args=(), kwargs=None):
+            if kwargs is None:
+                kwargs = {}
+            if parent_torch_function is None:
+                return NotImplemented
+            return parent_torch_function(func, types, args, kwargs)
+
+        BasevLLMParameter.__torch_function__ = classmethod(torch_function)
+        return
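
patch_for_pt27 is a no-op on torch 2.8.0 or newer; on older torch it replaces BasevLLMParameter.__torch_function__ with a classmethod that forwards to the parent class's implementation (or returns NotImplemented if the parent has none). A minimal, self-contained illustration of that dispatch pattern on a plain torch.Tensor subclass, not the vLLM classes, showing that a forwarding __torch_function__ keeps ordinary tensor ops working:

    import torch

    class ForwardingTensor(torch.Tensor):

        # __torch_function__ is resolved as a classmethod on tensor subclasses;
        # forwarding to the parent implementation mirrors the shape of the patch.
        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            return super().__torch_function__(func, types, args, kwargs)

    t = torch.ones(2).as_subclass(ForwardingTensor)
    print(t + 1)  # dispatches through __torch_function__ and computes normally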
