
Commit 8a72a9c

tlrmchlsmth authored and bnellnm committed

Clean up diff

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Bill Nell <bnell@redhat.com>
1 parent 549a9fe commit 8a72a9c

File tree

3 files changed, +36 -55 lines changed


vllm/cuda_graph_utils.py

Whitespace-only changes.

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 5 additions & 22 deletions
@@ -535,16 +535,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
     grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
         B.shape[1], META['BLOCK_SIZE_N']), )
 
-    if use_dg:
-        assert use_fp8_w8a8
-        # Note: we never apply the topk_weights here since it requires
-        # unpermuting and resizing the output. This goes against the
-        # existing interface as the `mul_routed_weight` argument is
-        # ignored. The weights are applied in _moe_unpermute.
-        dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (A, A_scale), (B, B_scale), C, expert_ids)
-
-    elif (use_int8_w8a16 or use_int4_w4a16) and \
+    if (use_int8_w8a16 or use_int4_w4a16) and \
             block_shape is not None and block_shape[1] > 0:
         assert B_scale is not None and B_scale.ndim == 3
         assert B_zp is None or B_zp.ndim == 3
@@ -848,7 +839,6 @@ def try_get_optimal_moe_config(
     M: int,
     is_marlin: bool = False,
     block_shape: Optional[List[int]] = None,
-    use_deep_gemm: bool = False,
 ):
     from vllm.model_executor.layers.fused_moe import get_config
     override_config = get_config()
@@ -871,11 +861,6 @@ def try_get_optimal_moe_config(
         # Else use the default config
         config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
                                     is_marlin, block_shape)
-
-    # Enforce DeepGemm M blocking no matter what the config says.
-    if use_deep_gemm:
-        config['BLOCK_SIZE_M'] = dg.get_m_alignment_for_contiguous_layout()
-
     return config
 
 
@@ -1048,14 +1033,13 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
                           w2_zp: Optional[torch.Tensor] = None,
                           a1_scale: Optional[torch.Tensor] = None,
                           a2_scale: Optional[torch.Tensor] = None,
-                          block_shape: Optional[List[int]] = None,
-                          allow_deep_gemm: bool = False) -> None:
+                          block_shape: Optional[List[int]] = None) -> None:
     fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
                        activation, apply_router_weight_on_input, use_fp8_w8a8,
                        use_int8_w8a8, use_int8_w8a16, use_int4_w4a16,
                        per_channel_quant, global_num_experts, expert_map,
                        w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
-                       block_shape, allow_deep_gemm)
+                       block_shape)
 
 
 def inplace_fused_experts_fake(
@@ -1489,7 +1473,6 @@ def fused_moe(
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
-    allow_deep_gemm: bool = True,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1523,8 +1506,8 @@ def fused_moe(
         Defaults to False.
     - global_num_experts (int): The total number of experts in the global
        expert space.
-    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
-        from the global expert space to the local expert space of the expert
+    - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
+        from the global expert space to the local expert space of the expert
        parallel shard.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
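
For context, here is a minimal, hypothetical call-site sketch (not part of this commit) showing how the trimmed-down entry point is used once the DeepGEMM arguments are gone. The tensor shapes and dtypes below are made up, and it assumes the public fused_moe function keeps its hidden_states / w1 / w2 / gating_output / topk / renormalize arguments:

# Hypothetical call site, for illustration only: after this commit the
# allow_deep_gemm / use_deep_gemm keyword arguments no longer exist, so
# callers simply pass the remaining options.
import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

E, topk, hidden, inter = 8, 2, 128, 256  # small, made-up sizes
hidden_states = torch.randn(4, hidden, dtype=torch.bfloat16, device="cuda")
w1 = torch.randn(E, 2 * inter, hidden, dtype=torch.bfloat16, device="cuda")
w2 = torch.randn(E, hidden, inter, dtype=torch.bfloat16, device="cuda")
gating_output = torch.randn(4, E, dtype=torch.float32, device="cuda")

out = fused_moe(hidden_states, w1, w2, gating_output, topk, renormalize=True)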

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 31 additions & 33 deletions
@@ -1,16 +1,15 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import abstractmethod
+from dataclasses import dataclass
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
-from dataclasses import dataclass
 
+import pplx_kernels as pplx
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import UninitializedParameter
 
-import pplx_kernels as pplx
-
 import vllm.envs as envs
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (get_dp_group, get_ep_group,
@@ -47,6 +46,7 @@
 
 MOE_DP_CHUNK_SIZE = 256
 
+
 # Adapted from pplx-kernels tests/all_to_all_utils.py
 @dataclass
 class MoEConfig:
@@ -64,6 +64,7 @@ class MoEConfig:
     out_dtype: torch.dtype = torch.bfloat16
     block_size: int = 128
 
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
@@ -100,26 +101,14 @@ def apply(
     ) -> torch.Tensor:
         raise NotImplementedError
 
+
+#TODO: Every change in this class is a broken hack!!
 @CustomOp.register("unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
-    def __init__(self, moe: MoEConfig):
-        self.all_to_all = pplx.AllToAll(
-            max_num_tokens=MOE_DP_CHUNK_SIZE // moe.dp_size,
-            num_experts=moe.num_experts,
-            experts_per_token=moe.experts_per_token,
-            rank=moe.ep_rank,
-            world_size=moe.ep_size,
-            dp_size=moe.dp_size,
-            hidden_dim=moe.hidden_dim,
-            hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
-            hidden_dim_scale_bytes=0,
-        )
-
 
-    def __init__(self):
+    def __init__(self, moe: MoEConfig):
         super().__init__()
-
         self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
         if self.rocm_aiter_moe_enabled:
             from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
@@ -903,7 +892,7 @@ def forward(self, hidden_states: torch.Tensor,
                                 self.layer_name)
 
     def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
-                            full_router_logits: torch.Tensor):
+                             full_router_logits: torch.Tensor):
         max_tokens_across_dp = get_forward_context(
         ).dp_metadata.max_tokens_across_dp
         cu_tokens_across_dp_cpu = get_forward_context(
@@ -919,21 +908,23 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
 
         num_tokens_remaining_across_dp = num_tokens_across_dp
         chunk_start = 0
-        chunk_end = min(moe_dp_chunk_size_per_rank, full_hidden_states.shape[0])
+        chunk_end = min(moe_dp_chunk_size_per_rank,
+                        full_hidden_states.shape[0])
         full_final_hidden_states = torch.empty_like(full_hidden_states)
 
         for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
-            hidden_states = full_hidden_states[chunk_start:chunk_end,:]
-            router_logits = full_router_logits[chunk_start:chunk_end,:]
+            hidden_states = full_hidden_states[chunk_start:chunk_end, :]
+            router_logits = full_router_logits[chunk_start:chunk_end, :]
 
             cu_tokens_across_dp_this_iter = torch.cumsum(
-                num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank),
+                num_tokens_remaining_across_dp.clamp(
+                    max=moe_dp_chunk_size_per_rank),
                 dim=0)
 
-            hidden_states = self.naive_multicast(hidden_states,
-                                                 cu_tokens_across_dp_this_iter)
-            router_logits = self.naive_multicast(router_logits,
-                                                 cu_tokens_across_dp_this_iter)
+            hidden_states = self.naive_multicast(
+                hidden_states, cu_tokens_across_dp_this_iter)
+            router_logits = self.naive_multicast(
+                router_logits, cu_tokens_across_dp_this_iter)
 
             # Matrix multiply.
             final_hidden_states = self.quant_method.apply(
@@ -954,7 +945,8 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
             )
 
             if self.dp_size > 1:
-                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[self.dp_rank-1]
+                start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_this_iter[
+                    self.dp_rank - 1]
                 end = cu_tokens_across_dp_this_iter[self.dp_rank]
 
                 all_hidden_states = get_dp_group().all_reduce(
@@ -963,20 +955,26 @@ def forward_impl_chunked(self, full_hidden_states: torch.Tensor,
 
             if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
                 # Default set to False. (May have to add shared expert outputs.)
-                final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+                final_hidden_states = tensor_model_parallel_all_reduce(
+                    final_hidden_states)
 
-            full_final_hidden_states[chunk_start:chunk_end, :].copy_(final_hidden_states)
+            full_final_hidden_states[chunk_start:chunk_end, :].copy_(
+                final_hidden_states)
 
             # Update bounds
-            num_tokens_remaining_across_dp = torch.clamp(num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0)
+            num_tokens_remaining_across_dp = torch.clamp(
+                num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank,
+                min=0)
+
             def update_chunk_bound(x: int):
-                return min(x + moe_dp_chunk_size_per_rank, full_hidden_states.shape[0])
+                return min(x + moe_dp_chunk_size_per_rank,
+                           full_hidden_states.shape[0])
+
             chunk_start = update_chunk_bound(chunk_start)
             chunk_end = update_chunk_bound(chunk_end)
 
         return full_final_hidden_states
 
-
     def forward_impl(self, hidden_states: torch.Tensor,
                      router_logits: torch.Tensor):
         assert self.quant_method is not None
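
For readers following the reformatted chunked loop above, here is a standalone sketch (made-up sizes, no vllm dependencies) of the bookkeeping forward_impl_chunked performs per iteration: advance chunk_start / chunk_end by at most moe_dp_chunk_size_per_rank tokens and clamp the per-rank remaining-token counts used to build the cumulative offsets:

# Standalone sketch of the chunk bookkeeping only; the real method also
# multicasts the chunk, runs the fused MoE kernel, and reduces the results.
import torch

moe_dp_chunk_size_per_rank = 4
num_tokens_across_dp = torch.tensor([10, 3])     # tokens on each DP rank
max_tokens_across_dp = int(num_tokens_across_dp.max())
num_local_tokens = int(num_tokens_across_dp[0])  # pretend we are DP rank 0

num_tokens_remaining_across_dp = num_tokens_across_dp.clone()
chunk_start = 0
chunk_end = min(moe_dp_chunk_size_per_rank, num_local_tokens)

for _ in range(0, max_tokens_across_dp, moe_dp_chunk_size_per_rank):
    # Per-iteration cumulative token offsets, capped at the chunk size.
    cu_tokens_this_iter = torch.cumsum(
        num_tokens_remaining_across_dp.clamp(max=moe_dp_chunk_size_per_rank),
        dim=0)
    print(chunk_start, chunk_end, cu_tokens_this_iter.tolist())

    # Update bounds for the next iteration.
    num_tokens_remaining_across_dp = torch.clamp(
        num_tokens_remaining_across_dp - moe_dp_chunk_size_per_rank, min=0)

    def update_chunk_bound(x: int) -> int:
        return min(x + moe_dp_chunk_size_per_rank, num_local_tokens)

    chunk_start = update_chunk_bound(chunk_start)
    chunk_end = update_chunk_bound(chunk_end)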
