Commit 4d8d68f

clean up object types and initialization
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent: 61ff67d

File tree

  • vllm/model_executor/layers/fused_moe/layer.py

1 file changed: 38 additions, 34 deletions

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 38 additions & 34 deletions
@@ -1381,7 +1381,7 @@ def __init__(
                 "Only softmax scoring function is supported for non-grouped topk."
             )
 
-        moe = FusedMoEConfig(
+        self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
@@ -1392,24 +1392,26 @@ def __init__(
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
         )
-        self.moe_config = moe
+
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config
 
+        def _get_quant_method() -> FusedMoEMethodBase:
+            """
+            Helper method to ensure self.quant_method is never None and
+            of the proper type.
+            """
+            quant_method = None
+            if self.quant_config is not None:
+                quant_method = self.quant_config.get_quant_method(self, prefix)
+            if quant_method is None:
+                quant_method = UnquantizedFusedMoEMethod(self.moe_config)
+            assert isinstance(quant_method, FusedMoEMethodBase)
+            return quant_method
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
-        quant_method: QuantizeMethodBase | None = None
-        quant_method = (
-            UnquantizedFusedMoEMethod(moe)
-            if quant_config is None
-            else quant_config.get_quant_method(self, prefix)
-        )
-        if quant_method is None:
-            quant_method = UnquantizedFusedMoEMethod(moe)
-
-        assert quant_method is not None
-        assert isinstance(quant_method, FusedMoEMethodBase)
-        self.quant_method = quant_method
+        self.quant_method: FusedMoEMethodBase = _get_quant_method()
 
         if not self.moe_config.is_act_and_mul:
             # Avoid circular import
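
For reference, the resolution order enforced by the new _get_quant_method helper, sketched with stand-in classes rather than the real vLLM types (FusedMoEMethodBase, UnquantizedFusedMoEMethod); illustrative only, not the actual implementation:

class MethodBase:                          # stands in for FusedMoEMethodBase
    pass

class UnquantizedMethod(MethodBase):       # stands in for UnquantizedFusedMoEMethod
    def __init__(self, moe_config):
        self.moe_config = moe_config

def resolve_quant_method(layer, quant_config, prefix, moe_config) -> MethodBase:
    # 1. Ask the quant config first, if one was provided.
    quant_method = None
    if quant_config is not None:
        quant_method = quant_config.get_quant_method(layer, prefix)
    # 2. Fall back to the unquantized implementation when nothing was returned.
    if quant_method is None:
        quant_method = UnquantizedMethod(moe_config)
    # 3. The caller always gets a usable method object, never None.
    assert isinstance(quant_method, MethodBase)
    return quant_method
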
@@ -1429,20 +1431,17 @@ def __init__(
                 "is_act_and_mul=False is supported only for CUDA for now"
             )
 
-        if self.enable_eplb:
-            from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
-
-            if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError(
-                    "EPLB is only supported for FP8 quantization for now."
-                )
+        if self.enable_eplb and not self.quant_method.supports_eplb:
+            # TODO: Add support for additional quantization methods.
+            # The implementation for other quantization methods does not
+            # contain essential differences, but the current quant API
+            # design causes duplicated work when extending to new
+            # quantization methods, so I'm leaving it for now.
+            # If you plan to add support for more quantization methods,
+            # please refer to the implementation in `Fp8MoEMethod`.
+            raise NotImplementedError(
+                "EPLB is only supported for FP8 quantization for now."
+            )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
@@ -1471,19 +1470,24 @@ def __init__(
             logits_shape: tuple[int, ...]
 
             # Note here we use `num_experts` which is logical expert count
+            max_num_tokens = self.moe_config.max_num_tokens
             if vllm_config.parallel_config.enable_dbo:
-                states_shape = (2, moe.max_num_tokens, self.hidden_size)
-                logits_shape = (2, moe.max_num_tokens, num_experts)
+                states_shape = (2, max_num_tokens, self.hidden_size)
+                logits_shape = (2, max_num_tokens, num_experts)
             else:
-                states_shape = (moe.max_num_tokens, self.hidden_size)
-                logits_shape = (moe.max_num_tokens, num_experts)
+                states_shape = (max_num_tokens, self.hidden_size)
+                logits_shape = (max_num_tokens, num_experts)
 
             self.batched_hidden_states = torch.zeros(
-                states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+                states_shape,
+                dtype=self.moe_config.in_dtype,
+                device=torch.cuda.current_device(),
             )
 
             self.batched_router_logits = torch.zeros(
-                logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+                logits_shape,
+                dtype=self.moe_config.in_dtype,
+                device=torch.cuda.current_device(),
             )
 
             # Note: init_prepare_finalize should only be called by
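
For reference, a small runnable sketch of the buffer shapes computed above, using placeholder sizes and a CPU float16 dtype in place of self.moe_config.in_dtype and torch.cuda.current_device():

import torch

# Placeholder values for illustration only; in the layer these come from
# self.moe_config, self.hidden_size, and num_experts.
max_num_tokens, hidden_size, num_experts = 128, 1024, 8
enable_dbo = True   # the DBO path doubles the leading dimension

if enable_dbo:
    states_shape = (2, max_num_tokens, hidden_size)
    logits_shape = (2, max_num_tokens, num_experts)
else:
    states_shape = (max_num_tokens, hidden_size)
    logits_shape = (max_num_tokens, num_experts)

batched_hidden_states = torch.zeros(states_shape, dtype=torch.float16)
batched_router_logits = torch.zeros(logits_shape, dtype=torch.float16)

assert batched_hidden_states.shape == (2, 128, 1024)
assert batched_router_logits.shape == (2, 128, 8)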
