
Commit 885d9ce

linting
1 parent 6f33557 commit 885d9ce

File tree

6 files changed, +25 -11 lines changed

vllm_ascend/ascend_forward_context.py

Lines changed: 1 addition & 2 deletions
@@ -95,8 +95,7 @@ def set_ascend_forward_context(
         forward_context.fused_moe_state = fused_moe_state
         forward_context.in_profile_run = in_profile_run

-        from vllm_ascend.ops.moe.token_dispatcher import \
-            get_token_dispatcher
+        from vllm_ascend.ops.moe.token_dispatcher import get_token_dispatcher
         dispatcher_name = get_dispatcher_name(ep_size, with_prefill)
         dispatcher = get_token_dispatcher(dispatcher_name)
         forward_context.token_dispatcher = dispatcher
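For context, the dispatcher stored on the forward context here is meant to be read back by downstream MoE code during the same forward pass. A minimal consumer sketch, assuming only what this hunk shows; the helper name is hypothetical, and vllm's get_forward_context must be called while a forward context is active:

from vllm.forward_context import get_forward_context


def get_current_token_dispatcher():
    """Hypothetical helper: return the dispatcher selected for this forward pass."""
    forward_context = get_forward_context()
    # `token_dispatcher` is attached by set_ascend_forward_context in the hunk above.
    return forward_context.token_dispatcher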

vllm_ascend/ops/common_fused_moe.py

Lines changed: 3 additions & 5 deletions
@@ -27,13 +27,11 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod)
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
-                                                 AlltoAllCommImpl,
-                                                 MC2CommImpl)
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.ops.moe.experts_selector import select_experts
-from vllm_ascend.ops.moe.token_dispatcher import \
-    setup_token_dispatchers
+from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
+                                                 AlltoAllCommImpl, MC2CommImpl)
+from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is

 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
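This hunk is purely import style: the vllm_ascend imports are re-sorted alphabetically by module path and the backslash continuation is collapsed, consistent with the commit message "linting" (which formatter produced it is not stated in the commit). A generic illustration of the two styles, using stdlib modules rather than repository code:

# Backslash continuation: legal Python, but commonly rewritten by import formatters.
from os.path import \
    join

# Preferred single-line form when the import fits the line-length limit.
from os.path import basename

# Preferred parenthesized form when several names do not fit on one line.
from collections.abc import (Iterable, Mapping,
                             Sequence)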

vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py

Lines changed: 18 additions & 4 deletions
@@ -9,6 +9,7 @@
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
+
 from vllm_ascend.distributed.communication_op import \
     data_parallel_reduce_scatter

@@ -19,11 +20,17 @@ def __init__(self, moe_config: Optional[FusedMoEConfig]):
         self.moe_config = moe_config

     @abstractmethod
-    def prepare(self):
+    def prepare(self,
+                hidden_states: torch.Tensor,
+                router_logits: torch.Tensor,
+                enable_shared_expert_dp: bool = False,
+                rm_router_logits: bool = False,
+                replace_allreduce: bool = False,
+                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         raise NotImplementedError("Prepare not implemented.")

-    @abstractmethod
-    def finalize(self):
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
         raise NotImplementedError("Combine function not implemented.")


@@ -91,6 +98,8 @@ def finalize(self, hidden_states: torch.Tensor,

         Also, unpad the hidden states if needed.
         """
+        assert self.moe_config.tp_group is not None, "tp_group cannot be None."
+
         if not (self.enable_shared_expert_dp or self.replace_all_reduce):
             if self.tp_size > 1:
                 dist.all_gather(list(self.split_hidden_states), hidden_states,
@@ -155,6 +164,8 @@ def finalize(self, hidden_states: torch.Tensor,

         Also, unpad the hidden states if needed.
         """
+        assert self.moe_config.tp_group is not None, "tp_group cannot be None."
+
         if not (self.enable_shared_expert_dp or self.replace_all_reduce):
             if self.tp_size > 1:
                 dist.all_gather(list(self.split_hidden_states), hidden_states,
@@ -180,9 +191,12 @@ def prepare(self,
                 replace_all_reduce: bool = False,
                 gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """When DP size > 1, pad the hidden states and router logits for communication."""
+        assert self.moe_config.dp_size is not None, "dp_size cannot be None."
+        assert self.moe_config.dp_group is not None, "dp_group cannot be None."
+
         self.rm_router_logits = rm_router_logits
         self.enable_shared_expert_dp = enable_shared_expert_dp
-
+
         if self.moe_config.dp_size > 1:
             forward_context = get_forward_context()
             max_tokens_across_dp = forward_context.max_tokens_across_dp
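The widened abstract signatures above spell out the contract that the concrete AllGather/All2All/MC2 implementations in this file fulfil. Below is a minimal sketch of a conforming subclass, assuming the base class exported by this module is named FusedMoEPrepareAndFinalize (the name is not visible in this diff), with deliberately trivial bodies standing in for the real padding and collective communication:

import torch

# Assumed base-class name; this diff only shows its abstract methods.
from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
    FusedMoEPrepareAndFinalize)


class NoOpPrepareAndFinalize(FusedMoEPrepareAndFinalize):
    """Illustrative subclass: satisfies the abstract contract with no padding
    and no collective communication."""

    def prepare(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                enable_shared_expert_dp: bool = False,
                rm_router_logits: bool = False,
                replace_allreduce: bool = False,
                gate=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Real implementations pad and gather across TP/DP ranks here; what the
        # third tensor of the returned triple carries is not shown in this diff,
        # so the sketch simply reuses the router logits.
        return hidden_states, router_logits, router_logits

    def finalize(self, hidden_states: torch.Tensor,
                 reduce_results: bool) -> torch.Tensor:
        # Real implementations all-gather / reduce-scatter and unpad here.
        return hidden_states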

vllm_ascend/ops/moe/moe_comm_method.py

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
 import torch
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
+
 from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import (
     FusedMoEPrepareAndFinalizeWithAll2All,
     FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2)

vllm_ascend/ops/moe/moe_mlp.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 import torch_npu
 from torch.nn.functional import pad
 from vllm.forward_context import get_forward_context
+
 from vllm_ascend.ascend_forward_context import FusedMoEState
 from vllm_ascend.utils import dispose_tensor, is_310p


vllm_ascend/ops/moe/token_dispatcher.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
 import torch
 import torch_npu
 from vllm.distributed.parallel_state import get_ep_group
+
 from vllm_ascend.distributed.parallel_state import get_mc2_group
 from vllm_ascend.distributed.tensor_parallel import \
     gather_from_sequence_parallel_region
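The single blank line added in each of the last three files (and in the first hunk of fused_moe_prepare_and_finalize.py) separates third-party imports (torch, torch_npu, vllm) from first-party vllm_ascend imports. A short illustration of the resulting grouping, assuming an isort-style configuration that treats vllm_ascend as first-party code (the project's actual formatter settings are not shown in this commit):

# Third-party section.
import torch
from vllm.distributed.parallel_state import get_ep_group

# First-party section: begins after the blank line these hunks add.
from vllm_ascend.distributed.parallel_state import get_mc2_group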
