Commit e1e0fd7

[TPU] Avoid Triton Import (#15589)
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
1 parent df8d3d1 commit e1e0fd7

File tree

2 files changed: +8, -6 lines

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 3 additions & 3 deletions

@@ -16,8 +16,6 @@
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    is_rocm_aiter_moe_enabled, shuffle_weights)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs

@@ -119,7 +117,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
             layer.w2_weight.data),
                                              requires_grad=False)
-
+        # Lazy import to avoid importing triton.
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            is_rocm_aiter_moe_enabled, shuffle_weights)
         if is_rocm_aiter_moe_enabled():
             # reshaping weights is required for aiter moe kernel.
             shuffled_w13, shuffled_w2 = shuffle_weights(

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 5 additions & 3 deletions

@@ -13,9 +13,6 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
-    expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
-    is_rocm_aiter_moe_enabled, shuffle_weights)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (

@@ -532,6 +529,11 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int,
         layer.w2_input_scale = None

     def process_weights_after_loading(self, layer: Module) -> None:
+        # Lazy import to avoid importing triton too early.
+        from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
+            expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
+            is_rocm_aiter_moe_enabled, shuffle_weights)
+
         # TODO (rob): refactor block quant into separate class.
         if self.block_quant:
             assert self.quant_config.activation_scheme == "dynamic"
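
A rough way to sanity-check a change like this is to import the patched module and inspect `sys.modules` — a sketch under the assumptions that vLLM is installed, nothing else has imported Triton first, and no other import path in vLLM pulls Triton in on its own:

import sys

# After this commit, the rocm_aiter path is only imported inside
# process_weights_after_loading, so importing the layer module
# should no longer drag in triton by itself.
import vllm.model_executor.layers.fused_moe.layer  # noqa: F401

print("triton in sys.modules:", "triton" in sys.modules)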

0 commit comments