Commit 2c03071

clean up object types and initialization
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent abb455b commit 2c03071

File tree: 1 file changed, +39 -35 lines
  • vllm/model_executor/layers/fused_moe

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 39 additions & 35 deletions
@@ -1363,7 +1363,7 @@ def __init__(
                 "Only softmax scoring function is supported for non-grouped topk."
             )
 
-        moe = FusedMoEConfig(
+        self.moe_config: FusedMoEConfig = FusedMoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
@@ -1373,39 +1373,38 @@ def __init__(
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
         )
-        self.moe_config = moe
+
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config
 
+        def _get_quant_method() -> FusedMoEMethodBase:
+            """
+            Helper method to ensure self.quant_method is never None and
+            of the proper type.
+            """
+            quant_method = None
+            if self.quant_config is not None:
+                quant_method = self.quant_config.get_quant_method(self, prefix)
+            if quant_method is None:
+                quant_method = UnquantizedFusedMoEMethod(self.moe_config)
+            assert isinstance(quant_method, FusedMoEMethodBase)
+            return quant_method
+
         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
-        quant_method: QuantizeMethodBase | None = None
-        quant_method = (
-            UnquantizedFusedMoEMethod(moe)
-            if quant_config is None
-            else quant_config.get_quant_method(self, prefix)
-        )
-        if quant_method is None:
-            quant_method = UnquantizedFusedMoEMethod(moe)
-
-        assert quant_method is not None
-        assert isinstance(quant_method, FusedMoEMethodBase)
-        self.quant_method = quant_method
-
-        if self.enable_eplb:
-            from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod
-
-            if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)):
-                # TODO: Add support for additional quantization methods.
-                # The implementation for other quantization methods does not
-                # contain essential differences, but the current quant API
-                # design causes duplicated work when extending to new
-                # quantization methods, so I'm leaving it for now.
-                # If you plan to add support for more quantization methods,
-                # please refer to the implementation in `Fp8MoEMethod`.
-                raise NotImplementedError(
-                    "EPLB is only supported for FP8 quantization for now."
-                )
+        self.quant_method: FusedMoEMethodBase = _get_quant_method()
+
+        if self.enable_eplb and not self.quant_method.supports_eplb:
+            # TODO: Add support for additional quantization methods.
+            # The implementation for other quantization methods does not
+            # contain essential differences, but the current quant API
+            # design causes duplicated work when extending to new
+            # quantization methods, so I'm leaving it for now.
+            # If you plan to add support for more quantization methods,
+            # please refer to the implementation in `Fp8MoEMethod`.
+            raise NotImplementedError(
+                "EPLB is only supported for FP8 quantization for now."
+            )
 
         moe_quant_params = {
             "num_experts": self.local_num_experts,
@@ -1433,19 +1432,24 @@ def __init__(
         logits_shape: tuple[int, ...]
 
         # Note here we use `num_experts` which is logical expert count
+        max_num_tokens = self.moe_config.max_num_tokens
         if vllm_config.parallel_config.enable_dbo:
-            states_shape = (2, moe.max_num_tokens, self.hidden_size)
-            logits_shape = (2, moe.max_num_tokens, num_experts)
+            states_shape = (2, max_num_tokens, self.hidden_size)
+            logits_shape = (2, max_num_tokens, num_experts)
         else:
-            states_shape = (moe.max_num_tokens, self.hidden_size)
-            logits_shape = (moe.max_num_tokens, num_experts)
+            states_shape = (max_num_tokens, self.hidden_size)
+            logits_shape = (max_num_tokens, num_experts)
 
         self.batched_hidden_states = torch.zeros(
-            states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+            states_shape,
+            dtype=self.moe_config.in_dtype,
+            device=torch.cuda.current_device(),
        )
 
         self.batched_router_logits = torch.zeros(
-            logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
+            logits_shape,
+            dtype=self.moe_config.in_dtype,
+            device=torch.cuda.current_device(),
        )
 
         # Note: init_prepare_finalize should only be called by
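Context on the EPLB change: the rewritten check `if self.enable_eplb and not self.quant_method.supports_eplb:` assumes each `FusedMoEMethodBase` subclass carries a `supports_eplb` flag, replacing the old local import of `Fp8MoEMethod` and the isinstance test against `(Fp8MoEMethod, UnquantizedFusedMoEMethod)`. That flag is not defined in this file's diff; the sketch below is only an assumption about how such a capability flag could be declared, not the commit's actual implementation, and the class bodies are placeholders.

```python
# Hypothetical sketch (not from this commit): a class-level capability flag on
# the quant-method base class lets the layer ask "does this method support
# EPLB?" without importing and isinstance-checking concrete classes.
from abc import ABC, abstractmethod
from typing import Any

import torch


class FusedMoEMethodBase(ABC):
    # Assumed default: quant methods must opt in to EPLB explicitly.
    supports_eplb: bool = False

    @abstractmethod
    def apply(self, layer: Any, x: torch.Tensor,
              router_logits: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
    # Unquantized MoE passed the old isinstance check, so it opts in here.
    supports_eplb = True

    def apply(self, layer: Any, x: torch.Tensor,
              router_logits: torch.Tensor) -> torch.Tensor:
        ...  # expert computation elided in this sketch


class Fp8MoEMethod(FusedMoEMethodBase):
    # FP8 is currently the only quantized method with EPLB support.
    supports_eplb = True

    def apply(self, layer: Any, x: torch.Tensor,
              router_logits: torch.Tensor) -> torch.Tensor:
        ...  # expert computation elided in this sketch
```

With a flag of this kind, the layer's `__init__` no longer needs a deferred import of `Fp8MoEMethod`, and a new quant method can advertise EPLB support by overriding a single attribute instead of being appended to a tuple of types in the layer code.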
