
Commit de74947

clean up object types and initialization
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent a03d8e7


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 27 additions & 28 deletions
@@ -1382,7 +1382,7 @@ def __init__(
                 "Only softmax scoring function is supported for non-grouped topk."
             )
 
-        moe = FusedMoEConfig(
+        self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
@@ -1393,24 +1393,26 @@ def __init__(
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
         )
-        self.moe_config = moe
+
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config
 
+        def _get_quant_method() -> FusedMoEMethodBase:
+            """
+            Helper method to ensure self.quant_method is never None and
+            of the proper type.
+            """
+            quant_method = None
+            if self.quant_config is not None:
+                quant_method = self.quant_config.get_quant_method(self, prefix)
+            if quant_method is None:
+                quant_method = UnquantizedFusedMoEMethod(self.moe_config)
+            assert isinstance(quant_method, FusedMoEMethodBase)
+            return quant_method
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
-        quant_method: QuantizeMethodBase | None = None
-        quant_method = (
-            UnquantizedFusedMoEMethod(moe)
-            if quant_config is None
-            else quant_config.get_quant_method(self, prefix)
-        )
-        if quant_method is None:
-            quant_method = UnquantizedFusedMoEMethod(moe)
-
-        assert quant_method is not None
-        assert isinstance(quant_method, FusedMoEMethodBase)
-        self.quant_method = quant_method
+        self.quant_method: FusedMoEMethodBase = _get_quant_method()
 
         if not self.moe_config.is_act_and_mul:
             # Avoid circular import
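
The net effect of this hunk is that self.quant_method is assigned exactly once, can never be None, and carries the narrower FusedMoEMethodBase type instead of the looser QuantizeMethodBase | None. A minimal standalone sketch of the same pattern follows; Config, QuantConfig, MethodBase, DefaultMethod, and Layer are hypothetical stand-ins for illustration, not the real vllm classes:

# Sketch of the "typed, never-None attribute via a local helper" pattern
# used in the diff above. All class names are hypothetical stand-ins.
from __future__ import annotations


class MethodBase:
    """Common base type that downstream code can rely on."""


class DefaultMethod(MethodBase):
    """Fallback used when no quantization-specific method applies."""

    def __init__(self, config: Config) -> None:
        self.config = config


class QuantConfig:
    def get_quant_method(self, layer: Layer, prefix: str) -> MethodBase | None:
        # A real quant config may return None when it has no method
        # for this particular layer.
        return None


class Config:
    pass


class Layer:
    def __init__(self, quant_config: QuantConfig | None, prefix: str = "") -> None:
        self.config = Config()
        self.quant_config = quant_config

        def _get_quant_method() -> MethodBase:
            # Mirror the diff: ask the quant config first, then fall back
            # to the default so the attribute is never None.
            method: MethodBase | None = None
            if self.quant_config is not None:
                method = self.quant_config.get_quant_method(self, prefix)
            if method is None:
                method = DefaultMethod(self.config)
            assert isinstance(method, MethodBase)
            return method

        # One assignment, one declared type; no `| None` leaks out.
        self.quant_method: MethodBase = _get_quant_method()


layer = Layer(quant_config=None)
assert isinstance(layer.quant_method, DefaultMethod)

The nested function keeps the fallback logic off the attribute-assignment line, so the annotation on that single assignment states the whole contract.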
@@ -1430,20 +1432,17 @@ def __init__(
                 "is_act_and_mul=False is supported only for CUDA for now"
             )
 
-        if self.enable_eplb:
-            from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
-
-            if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError(
-                    "EPLB is only supported for FP8 quantization for now."
-                )
+        if self.enable_eplb and not self.quant_method.supports_eplb:
+            # TODO: Add support for additional quantization methods.
+            # The implementation for other quantization methods does not
+            # contain essential differences, but the current quant API
+            # design causes duplicated work when extending to new
+            # quantization methods, so I'm leaving it for now.
+            # If you plan to add support for more quantization methods,
+            # please refer to the implementation in `Fp8MoEMethod`.
+            raise NotImplementedError(
+                "EPLB is only supported for FP8 quantization for now."
+            )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
